Refactored to security group api to support projects
This commit is contained in:
@@ -649,8 +649,6 @@ class AuthManager(object):
|
||||
def delete_user(self, user):
|
||||
"""Deletes a user"""
|
||||
with self.driver() as drv:
|
||||
for security_group in db.security_group_get_by_user(context = {}, user_id=User.safe_id(user)):
|
||||
db.security_group_destroy({}, security_group.id)
|
||||
drv.delete_user(User.safe_id(user))
|
||||
|
||||
def generate_key_pair(self, user, key_name):
|
||||
|
||||
@@ -442,17 +442,28 @@ def volume_update(context, volume_id, values):
|
||||
"""
|
||||
return IMPL.volume_update(context, volume_id, values)
|
||||
|
||||
|
||||
####################
|
||||
|
||||
|
||||
def security_group_create(context, values):
|
||||
"""Create a new security group"""
|
||||
return IMPL.security_group_create(context, values)
|
||||
def security_group_get_all(context):
|
||||
"""Get all security groups"""
|
||||
return IMPL.security_group_get_all(context)
|
||||
|
||||
|
||||
def security_group_get_by_id(context, security_group_id):
|
||||
def security_group_get(context, security_group_id):
|
||||
"""Get security group by its internal id"""
|
||||
return IMPL.security_group_get_by_id(context, security_group_id)
|
||||
return IMPL.security_group_get(context, security_group_id)
|
||||
|
||||
|
||||
def security_group_get_by_name(context, project_id, group_name):
|
||||
"""Returns a security group with the specified name from a project"""
|
||||
return IMPL.securitygroup_get_by_name(context, project_id, group_name)
|
||||
|
||||
|
||||
def security_group_get_by_project(context, project_id):
|
||||
"""Get all security groups belonging to a project"""
|
||||
return IMPL.securitygroup_get_by_project(context, project_id)
|
||||
|
||||
|
||||
def security_group_get_by_instance(context, instance_id):
|
||||
@@ -460,15 +471,10 @@ def security_group_get_by_instance(context, instance_id):
|
||||
return IMPL.security_group_get_by_instance(context, instance_id)
|
||||
|
||||
|
||||
def security_group_get_by_user(context, user_id):
|
||||
"""Get security groups owned by the given user"""
|
||||
return IMPL.security_group_get_by_user(context, user_id)
|
||||
|
||||
|
||||
def security_group_get_by_user_and_name(context, user_id, name):
|
||||
"""Get user's named security group"""
|
||||
return IMPL.security_group_get_by_user_and_name(context, user_id, name)
|
||||
|
||||
def security_group_create(context, values):
|
||||
"""Create a new security group"""
|
||||
return IMPL.security_group_create(context, values)
|
||||
|
||||
|
||||
def security_group_destroy(context, security_group_id):
|
||||
"""Deletes a security group"""
|
||||
|
||||
@@ -616,20 +616,45 @@ def volume_update(_context, volume_id, values):
|
||||
###################
|
||||
|
||||
|
||||
def security_group_create(_context, values):
|
||||
security_group_ref = models.SecurityGroup()
|
||||
for (key, value) in values.iteritems():
|
||||
security_group_ref[key] = value
|
||||
security_group_ref.save()
|
||||
return security_group_ref
|
||||
def security_group_get_all(_context):
|
||||
session = get_session()
|
||||
return session.query(models.SecurityGroup
|
||||
).options(eagerload('rules')
|
||||
).filter_by(deleted=False
|
||||
).all()
|
||||
|
||||
|
||||
def security_group_get_by_id(_context, security_group_id):
|
||||
def security_group_get(_context, security_group_id):
|
||||
session = get_session()
|
||||
with session.begin():
|
||||
return session.query(models.SecurityGroup
|
||||
).options(eagerload('rules')
|
||||
).get(security_group_id)
|
||||
|
||||
|
||||
def securitygroup_get_by_name(context, project_id, group_name):
|
||||
session = get_session()
|
||||
group_ref = session.query(models.SecurityGroup
|
||||
).options(eagerload('rules')
|
||||
).get(security_group_id)
|
||||
).filter_by(project_id=project_id
|
||||
).filter_by(name=group_name
|
||||
).filter_by(deleted=False
|
||||
).first()
|
||||
if not group_ref:
|
||||
raise exception.NotFound(
|
||||
'No security group named %s for project: %s' \
|
||||
% (group_name, project_id))
|
||||
|
||||
return group_ref
|
||||
|
||||
|
||||
def securitygroup_get_by_project(_context, project_id):
|
||||
session = get_session()
|
||||
return session.query(models.SecurityGroup
|
||||
).options(eagerload('rules')
|
||||
).filter_by(project_id=project_id
|
||||
).filter_by(deleted=False
|
||||
).all()
|
||||
|
||||
|
||||
def security_group_get_by_instance(_context, instance_id):
|
||||
@@ -638,34 +663,27 @@ def security_group_get_by_instance(_context, instance_id):
|
||||
return session.query(models.Instance
|
||||
).get(instance_id
|
||||
).security_groups \
|
||||
.all()
|
||||
.filter_by(deleted=False
|
||||
).all()
|
||||
|
||||
|
||||
def security_group_get_by_user(_context, user_id):
|
||||
session = get_session()
|
||||
with session.begin():
|
||||
return session.query(models.SecurityGroup
|
||||
).filter_by(user_id=user_id
|
||||
).filter_by(deleted=False
|
||||
).options(eagerload('rules')
|
||||
).all()
|
||||
def security_group_create(_context, values):
|
||||
security_group_ref = models.SecurityGroup()
|
||||
for (key, value) in values.iteritems():
|
||||
security_group_ref[key] = value
|
||||
security_group_ref.save()
|
||||
return security_group_ref
|
||||
|
||||
def security_group_get_by_user_and_name(_context, user_id, name):
|
||||
session = get_session()
|
||||
with session.begin():
|
||||
return session.query(models.SecurityGroup
|
||||
).filter_by(user_id=user_id
|
||||
).filter_by(name=name
|
||||
).filter_by(deleted=False
|
||||
).options(eagerload('rules')
|
||||
).one()
|
||||
|
||||
def security_group_destroy(_context, security_group_id):
|
||||
session = get_session()
|
||||
with session.begin():
|
||||
security_group = session.query(models.SecurityGroup
|
||||
).get(security_group_id)
|
||||
security_group.delete(session=session)
|
||||
# TODO(vish): do we have to use sql here?
|
||||
session.execute('update security_group set deleted=1 where id=:id',
|
||||
{'id': security_group_id})
|
||||
session.execute('update security_group_rule set deleted=1 '
|
||||
'where group_id=:id',
|
||||
{'id': security_group_id})
|
||||
|
||||
|
||||
###################
|
||||
|
||||
@@ -306,26 +306,23 @@ class SecurityGroup(BASE, NovaBase):
|
||||
|
||||
class SecurityGroupIngressRule(BASE, NovaBase):
|
||||
"""Represents a rule in a security group"""
|
||||
__tablename__ = 'security_group_rules'
|
||||
__tablename__ = 'security_group_rule'
|
||||
id = Column(Integer, primary_key=True)
|
||||
|
||||
parent_group_id = Column(Integer, ForeignKey('security_group.id'))
|
||||
parent_group = relationship("SecurityGroup", backref="rules", foreign_keys=parent_group_id,
|
||||
primaryjoin=parent_group_id==SecurityGroup.id)
|
||||
group_id = Column(Integer, ForeignKey('security_group.id'))
|
||||
group = relationship("SecurityGroup", backref="rules",
|
||||
foreign_keys=group_id,
|
||||
primaryjoin=group_id==SecurityGroup.id)
|
||||
|
||||
protocol = Column(String(5)) # "tcp", "udp", or "icmp"
|
||||
from_port = Column(Integer)
|
||||
to_port = Column(Integer)
|
||||
cidr = Column(String(255))
|
||||
|
||||
# Note: This is not the parent SecurityGroup. It's SecurityGroup we're
|
||||
# granting access for.
|
||||
group_id = Column(Integer, ForeignKey('security_group.id'))
|
||||
source_group_id = Column(Integer, ForeignKey('security_group.id'))
|
||||
|
||||
@property
|
||||
def user(self):
|
||||
return auth.manager.AuthManager().get_user(self.user_id)
|
||||
|
||||
cidr = Column(String(255))
|
||||
|
||||
class Network(BASE, NovaBase):
|
||||
"""Represents a network"""
|
||||
@@ -430,8 +427,9 @@ class FloatingIp(BASE, NovaBase):
|
||||
def register_models():
|
||||
"""Register Models and create metadata"""
|
||||
from sqlalchemy import create_engine
|
||||
models = (Service, Instance, Volume, ExportDevice,
|
||||
FixedIp, FloatingIp, Network, NetworkIndex) # , Image, Host)
|
||||
models = (Service, Instance, Volume, ExportDevice, FixedIp, FloatingIp,
|
||||
Network, NetworkIndex, SecurityGroup, SecurityGroupIngressRule)
|
||||
# , Image, Host
|
||||
engine = create_engine(FLAGS.sql_connection, echo=False)
|
||||
for model in models:
|
||||
model.metadata.create_all(engine)
|
||||
|
||||
@@ -216,7 +216,8 @@ class CloudController(object):
|
||||
@rbac.allow('all')
|
||||
def describe_security_groups(self, context, **kwargs):
|
||||
groups = []
|
||||
for group in db.security_group_get_by_user(context, context.user.id):
|
||||
for group in db.security_group_get_by_project(context,
|
||||
context.project.id):
|
||||
group_dict = {}
|
||||
group_dict['groupDescription'] = group.description
|
||||
group_dict['groupName'] = group.name
|
||||
@@ -229,10 +230,11 @@ class CloudController(object):
|
||||
rule_dict['toPort'] = rule.to_port
|
||||
rule_dict['groups'] = []
|
||||
rule_dict['ipRanges'] = []
|
||||
import pdb; pdb.set_trace()
|
||||
if rule.group_id:
|
||||
foreign_group = db.security_group_get_by_id({}, rule.group_id)
|
||||
rule_dict['groups'] += [ { 'groupName': foreign_group.name,
|
||||
'userId': foreign_group.user_id } ]
|
||||
source_group = db.security_group_get(context, rule.group_id)
|
||||
rule_dict['groups'] += [ { 'groupName': source_group.name,
|
||||
'userId': source_group.user_id } ]
|
||||
else:
|
||||
rule_dict['ipRanges'] += [ { 'cidrIp': rule.cidr } ]
|
||||
group_dict['ipPermissions'] += [ rule_dict ]
|
||||
@@ -258,23 +260,22 @@ class CloudController(object):
|
||||
user_id=None,
|
||||
source_security_group_name=None,
|
||||
source_security_group_owner_id=None):
|
||||
security_group = db.security_group_get_by_user_and_name(context,
|
||||
context.user.id,
|
||||
group_name)
|
||||
security_group = db.security_group_get_by_name(context,
|
||||
context.project.id,
|
||||
group_name)
|
||||
|
||||
criteria = {}
|
||||
|
||||
if source_security_group_name:
|
||||
if source_security_group_owner_id:
|
||||
other_user_id = source_security_group_owner_id
|
||||
else:
|
||||
other_user_id = context.user.id
|
||||
|
||||
foreign_security_group = \
|
||||
db.security_group_get_by_user_and_name(context,
|
||||
other_user_id,
|
||||
source_security_group_name)
|
||||
criteria['group_id'] = foreign_security_group.id
|
||||
source_project_id = self._get_source_project_id(context,
|
||||
source_security_group_owner_id)
|
||||
|
||||
source_security_group = \
|
||||
db.security_group_get_by_name(context,
|
||||
source_project_id,
|
||||
source_security_group_name)
|
||||
|
||||
criteria['group_id'] = source_security_group.id
|
||||
elif cidr_ip:
|
||||
criteria['cidr'] = cidr_ip
|
||||
else:
|
||||
@@ -303,22 +304,20 @@ class CloudController(object):
|
||||
ip_protocol=None, cidr_ip=None,
|
||||
source_security_group_name=None,
|
||||
source_security_group_owner_id=None):
|
||||
security_group = db.security_group_get_by_user_and_name(context,
|
||||
context.user.id,
|
||||
group_name)
|
||||
values = { 'parent_group_id' : security_group.id }
|
||||
security_group = db.security_group_get_by_name(context,
|
||||
context.project.id,
|
||||
group_name)
|
||||
values = { 'group_id' : security_group.id }
|
||||
|
||||
if source_security_group_name:
|
||||
if source_security_group_owner_id:
|
||||
other_user_id = source_security_group_owner_id
|
||||
else:
|
||||
other_user_id = context.user.id
|
||||
source_project_id = self._get_source_project_id(context,
|
||||
source_security_group_owner_id)
|
||||
|
||||
foreign_security_group = \
|
||||
db.security_group_get_by_user_and_name(context,
|
||||
other_user_id,
|
||||
source_security_group_name)
|
||||
values['group_id'] = foreign_security_group.id
|
||||
source_security_group = \
|
||||
db.security_group_get_by_name(context,
|
||||
source_project_id,
|
||||
source_security_group_name)
|
||||
values['source_group_id'] = source_security_group.id
|
||||
elif cidr_ip:
|
||||
values['cidr'] = cidr_ip
|
||||
else:
|
||||
@@ -336,18 +335,38 @@ class CloudController(object):
|
||||
|
||||
security_group_rule = db.security_group_rule_create(context, values)
|
||||
return True
|
||||
|
||||
def _get_source_project_id(self, context, source_security_group_owner_id):
|
||||
if source_security_group_owner_id:
|
||||
# Parse user:project for source group.
|
||||
source_parts = source_security_group_owner_id.split(':')
|
||||
|
||||
# If no project name specified, assume it's same as user name.
|
||||
# Since we're looking up by project name, the user name is not
|
||||
# used here. It's only read for EC2 API compatibility.
|
||||
if len(source_parts) == 2:
|
||||
source_project_id = parts[1]
|
||||
else:
|
||||
source_project_id = parts[0]
|
||||
else:
|
||||
source_project_id = context.project.id
|
||||
|
||||
return source_project_id
|
||||
|
||||
@rbac.allow('netadmin')
|
||||
def create_security_group(self, context, group_name, group_description):
|
||||
db.security_group_create(context,
|
||||
values = { 'user_id' : context.user.id,
|
||||
'project_id': context.project.id,
|
||||
'name': group_name,
|
||||
'description': group_description })
|
||||
return True
|
||||
|
||||
@rbac.allow('netadmin')
|
||||
def delete_security_group(self, context, group_name, **kwargs):
|
||||
security_group = db.security_group_get_by_user_and_name(context, context.user.id, group_name)
|
||||
security_group = db.security_group_get_by_name(context,
|
||||
context.project.id,
|
||||
group_name)
|
||||
db.security_group_destroy(context, security_group.id)
|
||||
return True
|
||||
|
||||
|
||||
@@ -304,6 +304,7 @@ class ApiEc2TestCase(test.BaseTestCase):
|
||||
# be good enough for that.
|
||||
for group in rv:
|
||||
if group.name == security_group_name:
|
||||
import pdb; pdb.set_trace()
|
||||
self.assertEquals(len(group.rules), 1)
|
||||
self.assertEquals(int(group.rules[0].from_port), 80)
|
||||
self.assertEquals(int(group.rules[0].to_port), 81)
|
||||
|
||||
@@ -86,6 +86,8 @@ class NWFilterTestCase(test.TrialTestCase):
|
||||
context.user = FakeContext()
|
||||
context.user.id = 'fake'
|
||||
context.user.is_superuser = lambda:True
|
||||
context.project = FakeContext()
|
||||
context.project.id = 'fake'
|
||||
cloud_controller.create_security_group(context, 'testgroup', 'test group description')
|
||||
cloud_controller.authorize_security_group_ingress(context, 'testgroup', from_port='80',
|
||||
to_port='81', ip_protocol='tcp',
|
||||
@@ -93,7 +95,7 @@ class NWFilterTestCase(test.TrialTestCase):
|
||||
|
||||
fw = libvirt_conn.NWFilterFirewall()
|
||||
|
||||
security_group = db.security_group_get_by_user_and_name({}, 'fake', 'testgroup')
|
||||
security_group = db.security_group_get_by_name({}, 'fake', 'testgroup')
|
||||
|
||||
xml = fw.security_group_to_nwfilter_xml(security_group.id)
|
||||
|
||||
|
||||
@@ -492,7 +492,7 @@ class NWFilterFirewall(object):
|
||||
</filter>'''
|
||||
|
||||
def security_group_to_nwfilter_xml(self, security_group_id):
|
||||
security_group = db.security_group_get_by_id({}, security_group_id)
|
||||
security_group = db.security_group_get({}, security_group_id)
|
||||
rule_xml = ""
|
||||
for rule in security_group.rules:
|
||||
rule_xml += "<rule action='allow' direction='in' priority='900'>"
|
||||
|
||||
Reference in New Issue
Block a user