Update hdsdiscovery and API

Support pica8 as vendor in hdsdiscovery
Update API and db model
Add more testing

Change-Id: I2a02f25d030967b0b83a16bb24741f2e7a48d46c
This commit is contained in:
grace.yu 2014-01-23 11:35:53 -08:00
parent 7c31e6a7fd
commit b49f814505
20 changed files with 1280 additions and 646 deletions

View File

@ -2,8 +2,9 @@
import logging
from compass.db import database
from compass.db.model import Switch, Machine
from compass.db.model import Switch, Machine, SwitchConfig
from compass.hdsdiscovery.hdmanager import HDManager
from sqlalchemy.exc import IntegrityError
def poll_switch(ip_addr, req_obj='mac', oper="SCAN"):
@ -24,6 +25,9 @@ def poll_switch(ip_addr, req_obj='mac', oper="SCAN"):
The function should be called inside database session scope.
"""
UNDERMONITORING = 'under_monitoring'
UNREACHABLE = 'unreachable'
if not ip_addr:
logging.error('No switch IP address is provided!')
return
@ -37,40 +41,65 @@ def poll_switch(ip_addr, req_obj='mac', oper="SCAN"):
return
credential = switch.credential
logging.error("pollswitch: credential %r", credential)
logging.info("pollswitch: credential %r", credential)
vendor = switch.vendor
prev_state = switch.state
hdmanager = HDManager()
if not vendor or not hdmanager.is_valid_vendor(ip_addr,
credential, vendor):
# No vendor found or vendor doesn't match queried switch.
logging.debug('no vendor or vendor had been changed for switch %s',
switch)
vendor = hdmanager.get_vendor(ip_addr, credential)
logging.debug('[pollswitch] credential %r', credential)
if not vendor:
logging.error('no vendor found or match switch %s', switch)
return
switch.vendor = vendor
vendor, vstate, err_msg = hdmanager.get_vendor(ip_addr, credential)
if not vendor:
switch.state = vstate
switch.err_msg = err_msg
logging.info("*****error_msg: %s****", switch.err_msg)
logging.error('no vendor found or match switch %s', switch)
return
switch.vendor = vendor
# Start to poll switch's mac address.....
logging.debug('hdmanager learn switch from %s %s %s %s %s',
ip_addr, credential, vendor, req_obj, oper)
results = hdmanager.learn(ip_addr, credential, vendor, req_obj, oper)
results = []
try:
results = hdmanager.learn(ip_addr, credential, vendor, req_obj, oper)
except:
switch.state = UNREACHABLE
switch.err_msg = "SNMP walk for querying MAC addresses timedout"
return
logging.info("pollswitch %s result: %s", switch, results)
if not results:
logging.error('no result learned from %s %s %s %s %s',
ip_addr, credential, vendor, req_obj, oper)
return
switch_id = switch.id
filter_ports = session.query(SwitchConfig.filter_port)\
.join(Switch)\
.filter(SwitchConfig.ip == Switch.ip)\
.filter(Switch.id == switch_id).all()
if filter_ports:
#Get all ports from tuples into list
filter_ports = [i[0] for i in filter_ports]
for entry in results:
mac = entry['mac']
machine = session.query(Machine).filter_by(mac=mac).first()
if not machine:
machine = Machine(mac=mac)
machine.port = entry['port']
machine.vlan = entry['vlan']
port = entry['port']
vlan = entry['vlan']
if port in filter_ports:
continue
try:
machine = Machine(mac=mac, port=port, vlan=vlan)
session.add(machine)
machine.switch = switch
except IntegrityError as e:
logging.debug('The record already exists in db! Error: %s', e)
continue
logging.debug('update switch %s state to under monitoring', switch)
switch.state = 'under_monitoring'
if prev_state != UNDERMONITORING:
#Update error message in db
switch.err_msg = ""
switch.state = UNDERMONITORING

View File

@ -46,7 +46,8 @@ class SwitchList(Resource):
limit = request.args.get(self.LIMIT, 0, type=int)
if switch_ips and switch_ip_network:
error_msg = 'switchIp and switchIpNetwork cannot be combined!'
error_msg = ("switchIp and switchIpNetwork cannot be "
"specified at the same time!")
return errors.handle_invalid_usage(
errors.UserInvalidUsage(error_msg))
@ -62,12 +63,16 @@ class SwitchList(Resource):
error_msg = 'SwitchIp format is incorrect!'
return errors.handle_invalid_usage(
errors.UserInvalidUsage(error_msg))
switch = session.query(ModelSwitch).filter_by(ip=ip_addr)\
.first()
if switch:
switches.append(switch)
logging.info('[SwitchList][get] ip %s', ip_addr)
if limit:
switches = switches[:limit]
elif switch_ip_network:
# query all switches which belong to the same network
if not util.is_valid_ipnetowrk(switch_ip_network):
@ -91,8 +96,14 @@ class SwitchList(Resource):
ip_network.prefixlen)
logging.info('ip_filter is %s', ip_filter)
result_set = session.query(ModelSwitch).filter(
ModelSwitch.ip.startswith(ip_filter)).all()
result_set = []
if limit:
result_set = session.query(ModelSwitch).filter(
ModelSwitch.ip.startswith(ip_filter)).limit(limit)\
.all()
else:
result_set = session.query(ModelSwitch).filter(
ModelSwitch.ip.startswith(ip_filter)).all()
for switch in result_set:
ip_addr = str(switch.ip)
@ -100,19 +111,19 @@ class SwitchList(Resource):
switches.append(switch)
logging.info('[SwitchList][get] ip %s', ip_addr)
if limit and len(switches) > limit:
switches = switches[:limit]
elif limit and not switches:
switches = session.query(ModelSwitch).limit(limit).all()
else:
switches = session.query(ModelSwitch).all()
if not switch_ips and not switch_ip_network:
if limit:
switches = session.query(ModelSwitch).limit(limit).all()
else:
switches = session.query(ModelSwitch).all()
for switch in switches:
switch_res = {}
switch_res['id'] = switch.id
switch_res['ip'] = switch.ip
switch_res['state'] = switch.state
if switch.state != 'under_monitoring':
switch_res['err_msg'] = switch.err_msg
switch_res['link'] = {
'rel': 'self',
'href': '/'.join((self.ENDPOINT, str(switch.id)))}
@ -203,6 +214,8 @@ class Switch(Resource):
switch_res['id'] = switch.id
switch_res['ip'] = switch.ip
switch_res['state'] = switch.state
if switch.state != 'under_monitoring':
switch_res['err_msg'] = switch.err_msg
switch_res['link'] = {
'rel': 'self',
'href': '/'.join((self.ENDPOINT, str(switch.id)))}
@ -218,6 +231,12 @@ class Switch(Resource):
:param switch_id: the unqiue identifier of the switch
"""
switch = None
credential = None
logging.debug('PUT a switch request from curl is %s', request.data)
ip_addr = None
switch_res = {}
with database.session() as session:
switch = session.query(ModelSwitch).filter_by(id=switch_id).first()
logging.info('PUT switch id is %s: %s', switch_id, switch)
@ -228,19 +247,15 @@ class Switch(Resource):
return errors.handle_not_exist(
errors.ObjectDoesNotExist(error_msg))
credential = None
logging.debug('PUT a switch request from curl is %s', request.data)
json_data = json.loads(request.data)
credential = json_data['switch']['credential']
json_data = json.loads(request.data)
credential = json_data['switch']['credential']
logging.info('PUT switch id=%s credential=%s(%s)',
switch_id, credential, type(credential))
logging.info('PUT switch id=%s credential=%s(%s)',
switch_id, credential, type(credential))
ip_addr = None
switch_res = {}
with database.session() as session:
switch = session.query(ModelSwitch).filter_by(id=switch_id).first()
switch.credential = credential
switch.state = "not_reached"
switch.state = "repolling"
switch.err_msg = ""
ip_addr = switch.ip
switch_res['id'] = switch.id
@ -276,17 +291,17 @@ class MachineList(Resource):
def get(self):
"""
Lists details of machines, optionally filtered by some conditions as
the following.
the following. According to SwitchConfig, machines with some ports will
be filtered.
:param switchId: the unique identifier of the switch
:param vladId: the vlan ID
:param port: the port number
:param limit: the number of records expected to return
"""
machines_result = []
switch_id = request.args.get(self.SWITCHID, type=int)
vlan = request.args.get(self.VLANID, type=int)
port = request.args.get(self.PORT, type=int)
port = request.args.get(self.PORT, None)
limit = request.args.get(self.LIMIT, 0, type=int)
with database.session() as session:
@ -299,7 +314,7 @@ class MachineList(Resource):
filter_clause.append('vlan=%d' % vlan)
if port:
filter_clause.append('port=%d' % port)
filter_clause.append('port=%s' % port)
if limit < 0:
error_msg = 'Limit cannot be less than 0!'
@ -308,21 +323,15 @@ class MachineList(Resource):
)
if filter_clause:
if limit:
machines = session.query(ModelMachine)\
.filter(and_(*filter_clause))\
.limit(limit).all()
else:
machines = session.query(ModelMachine)\
.filter(and_(*filter_clause)).all()
machines = session.query(ModelMachine)\
.filter(and_(*filter_clause)).all()
else:
if limit:
machines = session.query(ModelMachine).limit(limit).all()
else:
machines = session.query(ModelMachine).all()
machines = session.query(ModelMachine).all()
logging.info('all machines: %s', machines)
machines_result = []
for machine in machines:
if limit and len(machines_result) == limit:
break
machine_res = {}
machine_res['switch_ip'] = None if not machine.switch else \
machine.switch.ip
@ -627,11 +636,12 @@ def execute_cluster_action(cluster_id):
with database.session() as session:
failed_hosts = []
for host_id in hosts:
host = session.query(ModelClusterHost).filter_by(id=host_id)\
.first()
host = session.query(ModelClusterHost)\
.filter_by(id=host_id, cluster_id=cluster_id)\
.first()
if not host:
failed_hosts.append(host)
failed_hosts.append(host_id)
continue
host_res = {
@ -641,7 +651,7 @@ def execute_cluster_action(cluster_id):
removed_hosts.append(host_res)
if failed_hosts:
error_msg = 'Cluster hosts do not exist!'
error_msg = 'Hosts do not exist! Or not in this cluster'
value = {
"failedHosts": failed_hosts
}
@ -674,24 +684,43 @@ def execute_cluster_action(cluster_id):
session.flush()
return _add_hosts(cluster_id, hosts)
def _deploy(cluster_id):
def _deploy(cluster_id, hosts):
"""Deploy the cluster"""
deploy_hosts_urls = []
deploy_hosts_info = []
deploy_cluster_info = {}
with database.session() as session:
cluster_hosts_ids = session.query(ModelClusterHost.id)\
if not hosts:
# Deploy all hosts in the cluster
cluster_hosts = session.query(ModelClusterHost)\
.filter_by(cluster_id=cluster_id).all()
if not cluster_hosts_ids:
# No host belongs to this cluster
error_msg = ('Cannot find any host in cluster id=%s' %
cluster_id)
return errors.handle_not_exist(
errors.ObjectDoesNotExist(error_msg))
for elm in cluster_hosts_ids:
host_id = str(elm[0])
if not cluster_hosts:
# No host belongs to this cluster
error_msg = ('Cannot find any host in cluster id=%s' %
cluster_id)
return errors.handle_not_exist(
errors.ObjectDoesNotExist(error_msg))
for host in cluster_hosts:
if not host.mutable:
# The host is not allowed to modified
error_msg = ("The host id=%s is not allowed to be "
"modified now!") % host.id
return errors.UserInvalidUsage(
errors.UserInvalidUsage(error_msg))
hosts.append(host.id)
deploy_cluster_info["cluster_id"] = int(cluster_id)
deploy_cluster_info["url"] = '/clusters/%s/progress' % cluster_id
for host_id in hosts:
host_info = {}
progress_url = '/cluster_hosts/%s/progress' % host_id
deploy_hosts_urls.append(progress_url)
host_info["host_id"] = host_id
host_info["url"] = progress_url
deploy_hosts_info.append(host_info)
# Lock cluster hosts and its cluster
session.query(ModelClusterHost).filter_by(cluster_id=cluster_id)\
@ -699,10 +728,13 @@ def execute_cluster_action(cluster_id):
session.query(ModelCluster).filter_by(id=cluster_id)\
.update({'mutable': False})
celery.send_task("compass.tasks.trigger_install", (cluster_id,))
celery.send_task("compass.tasks.trigger_install", (cluster_id, hosts))
return util.make_json_response(
202, {"status": "OK",
"deployment": deploy_hosts_urls})
202, {"status": "accepted",
"deployment": {
"cluster": deploy_cluster_info,
"hosts": deploy_hosts_info
}})
request_data = None
with database.session() as session:
@ -714,7 +746,8 @@ def execute_cluster_action(cluster_id):
)
if not cluster.mutable:
# The cluster cannot be deploy again
error_msg = "The cluster id=%s cannot deploy again!" % cluster_id
error_msg = ("The cluster id=%s is not allowed to "
"modified or deployed!" % cluster_id)
return errors.handle_invalid_usage(
errors.UserInvalidUsage(error_msg))
@ -729,7 +762,7 @@ def execute_cluster_action(cluster_id):
return _remove_hosts(cluster_id, value)
elif 'deploy' in request_data:
return _deploy(cluster_id)
return _deploy(cluster_id, value)
elif 'replaceAllHosts' in request_data:
return _replace_all_hosts(cluster_id, value)
@ -791,6 +824,19 @@ class ClusterHostConfig(Resource):
#Valid if keywords in request_data are all correct
if 'hostname' in request_data:
hostname = request_data['hostname']
cluster_id = host.cluster_id
test_host = session.query(ModelClusterHost)\
.filter_by(cluster_id=cluster_id,
hostname=hostname).first()
if test_host:
error_msg = ("Hostname '%s' has been used for other host "
"in the cluster, cluster ID is %s!"
% hostname, cluster_id)
return errors.handle_invalid_usage(
errors.UserInvalidUsage(error_msg))
session.query(ModelClusterHost).filter_by(id=host_id)\
.update({"hostname": request_data['hostname']})
del request_data['hostname']

View File

@ -1,6 +1,4 @@
"""Utils for API usage"""
import logging
from flask import make_response
from flask.ext.restful import Api
@ -85,166 +83,172 @@ def is_valid_gateway(ip_addr):
return False
def _is_valid_nameservers(value):
if value:
nameservers = value.strip(",").split(",")
for elem in nameservers:
if not is_valid_ip(elem):
return False
else:
return False
return True
def is_valid_security_config(config):
"""Valid the format of security section in config"""
outer_format = {
"server_credentials": {}, "service_credentials": {},
"console_credentials": {}
}
inner_format = {
"username": {}, "password": {}
}
valid_outter, err = is_valid_keys(outer_format, config, "Security")
if not valid_outter:
return (False, err)
security_keys = ['server_credentials', 'service_credentials',
'console_credentials']
fields = ['username', 'password']
logging.debug('config: %s', config)
for key in security_keys:
try:
content = config[key]
except KeyError:
error_msg = "Missing '%s' in security config!" % key
logging.error(error_msg)
raise KeyError(error_msg)
for key in config:
content = config[key]
valid_inner, err = is_valid_keys(inner_format, content, key)
if not valid_inner:
return (False, err)
for k in fields:
try:
value = content[k]
if not value:
return False, '%s in %s cannot be null!' % (k, key)
except KeyError:
error_msg = ("Missing '%s' in '%s' section of security config"
% (k, key))
logging.error(error_msg)
raise KeyError(error_msg)
return True, 'valid!'
for sub_key in content:
if not content[sub_key]:
return (False, ("The value of %s in %s in security config "
"cannot be None!") % (sub_key, key))
return (True, '')
def is_valid_networking_config(config):
"""Valid the format of networking config"""
networking = ['interfaces', 'global']
def _is_valid_interfaces_config(interfaces_config):
"""Valid the format of interfaces section in config"""
interfaces_section = {
"management": {}, "tenant": {}, "public": {}, "storage": {}
}
section = {
"ip_start": {"req": 1, "validator": is_valid_ip},
"ip_end": {"req": 1, "validator": is_valid_ip},
"netmask": {"req": 1, "validator": is_valid_netmask},
"gateway": {"req": 0, "validator": is_valid_gateway},
"nic": {},
"promisc": {}
}
expected_keys = ['management', 'tenant', 'public', 'storage']
required_fields = ['nic', 'promisc']
normal_fields = ['ip_start', 'ip_end', 'netmask']
other_fields = ['gateway', 'vlan']
# Check if interfaces outer layer keywords
is_valid_outer, err = is_valid_keys(interfaces_section,
interfaces_config, "interfaces")
if not is_valid_outer:
return (False, err)
interfaces_keys = interfaces_config.keys()
for key in expected_keys:
if key not in interfaces_keys:
error_msg = "Missing '%s' in interfaces config!" % key
return False, error_msg
promisc_nics = []
nonpromisc_nics = []
for key in interfaces_config:
content = interfaces_config[key]
for field in required_fields:
if field not in content:
error_msg = "Keyword '%s' in interface %s cannot be None!"\
% (field, key)
return False, error_msg
is_valid_inner, err = is_valid_keys(section, content, key)
if not is_valid_inner:
return (False, err)
value = content[field]
if value is None:
error_msg = ("The value of '%s' in '%s' "
'config cannot be None!' %
(field, key))
return False, error_msg
if content["promisc"] not in [0, 1]:
return (False, ("The value of Promisc in %s section of "
"interfaces can only be either 0 or 1!") % key)
if not content["nic"]:
return (False, ("The NIC in %s cannot be None!") % key)
if field == 'promisc':
valid_values = [0, 1]
if int(value) not in valid_values:
return (
False,
('The value of Promisc for interface %s can '
'only be 0/1.bit_length' % key)
)
if content["promisc"]:
if content["nic"] not in nonpromisc_nics:
promisc_nics.append(content["nic"])
continue
else:
return (False,
("The NIC in %s cannot be assigned in promisc "
"and nonpromisc mode at the same time!" % key))
else:
if content["nic"] not in promisc_nics:
nonpromisc_nics.append(content["nic"])
else:
return (False,
("The NIC in %s cannot be assigned in promisc "
"and nonpromisc mode at the same time!" % key))
elif field == 'nic':
if not value.startswith('eth'):
return (
False,
('The value of nic for interface %s should start'
'with eth' % key)
)
# Validate other keywords in the section
for sub_key in content:
if sub_key == "promisc" or sub_key == "nic":
continue
value = content[sub_key]
is_required = section[sub_key]["req"]
validator = section[sub_key]["validator"]
if value:
if validator and not validator(value):
error_msg = "The format of %s in %s is invalid!" % \
(sub_key, key)
return (False, error_msg)
if not content['promisc']:
for field in normal_fields:
value = content[field]
if field == 'netmask' and not is_valid_netmask(value):
return (False, "Invalid netmask format for interface "
" %s: '%s'!" % (key, value))
elif not is_valid_ip(value):
return (False,
"Invalid Ip format for interface %s: '%s'"
% (key, value))
elif is_required:
return (False,
("%s in %s section in interfaces of networking "
"config cannot be None!") % (sub_key, key))
for field in other_fields:
if field in content and field == 'gateway':
value = content[field]
if value and not is_valid_gateway(value):
return False, "Invalid gateway format '%s'" % value
return True, 'Valid!'
return (True, '')
def _is_valid_global_config(global_config):
"""Valid the format of 'global' section in config"""
global_section = {
"nameservers": {"req": 1, "validator": _is_valid_nameservers},
"search_path": {"req": 1, "validator": ""},
"gateway": {"req": 1, "validator": is_valid_gateway},
"proxy": {"req": 0, "validator": ""},
"ntp_server": {"req": 0, "validator": ""}
}
is_valid_format, err = is_valid_keys(global_section, global_config,
"global")
if not is_valid_format:
return (False, err)
required_fields = ['nameservers', 'search_path', 'gateway']
global_keys = global_config.keys()
for key in required_fields:
if key not in global_keys:
error_msg = ("Missing %s in global config of networking config"
% key)
return False, error_msg
for key in global_section:
value = global_config[key]
if not value:
error_msg = ("Value of %s in global config cannot be None!" %
key)
return False, error_msg
is_required = global_section[key]["req"]
validator = global_section[key]["validator"]
if key == 'nameservers':
nameservers = [nameserver for nameserver in value.split(',')
if nameserver]
for nameserver in nameservers:
if not is_valid_ip(nameserver):
return (
False,
"The nameserver format is invalid! '%s'" % value
)
if value:
if validator and not validator(value):
return (False, ("The format of %s in global section of "
"networking config is invalid!") % key)
elif is_required:
return (False, ("The value of %s in global section of "
"netowrking config cannot be None!") % key)
elif key == 'gateway' and not is_valid_gateway(value):
return False, "The gateway format is invalid! '%s'" % value
return (True, '')
return True, 'Valid!'
networking_config = {
"interfaces": _is_valid_interfaces_config,
"global": _is_valid_global_config
}
#networking_keys = networking.keys()
is_valid = False
msg = None
for nkey in networking:
if nkey in config:
content = config[nkey]
valid_format, err = is_valid_keys(networking_config, config, "networking")
if not valid_format:
return (False, err)
if nkey == 'interfaces':
is_valid, msg = _is_valid_interfaces_config(content)
elif nkey == 'global':
is_valid, msg = _is_valid_global_config(content)
for key in networking_config:
validator = networking_config[key]
is_valid, err = validator(config[key])
if not is_valid:
return (False, err)
if not is_valid:
return is_valid, msg
else:
error_msg = "Missing '%s' in networking config!" % nkey
return False, error_msg
return True, 'valid!'
return (True, '')
def is_valid_partition_config(config):
"""Valid the configuration format"""
if not config:
return False, '%s in partition cannot be null!' % config
return (False, '%s in partition cannot be null!' % config)
return True, 'valid!'
return (True, '')
def valid_host_config(config):
@ -308,3 +312,41 @@ def update_dict_value(searchkey, newvalue, dictionary):
update_dict_value(searchkey, newvalue, dictionary[key])
else:
continue
def is_valid_keys(expected, input_dict, section=""):
excepted_keys = set(expected.keys())
input_keys = set(input_dict.keys())
if excepted_keys != input_keys:
invalid_keys = list(excepted_keys - input_keys) if \
len(excepted_keys) > len(input_keys) else\
list(input_keys - excepted_keys)
error_msg = ("Invalid or missing keywords in the %s "
"section of networking config. Please check these "
"keywords %s") % (section, invalid_keys)
return (False, error_msg)
return (True, "")
def is_same_dict_keys(expected_dict, config_dict):
if not expected_dict or not config_dict:
return (False, "The Config cannot be None!")
if expected_dict.viewkeys() == config_dict.viewkeys():
for expected_key, config_key in zip(expected_dict, config_dict):
if isinstance(expected_dict[expected_key], str):
return (True, "")
is_same, err = is_same_dict_keys(expected_dict[expected_key],
config_dict[config_key])
if not is_same:
return (False, err)
return (True, "")
if len(expected_dict) >= len(config_dict):
invalid_list = list(expected_dict.viewkeys() - config_dict.viewkeys())
else:
invalid_list = list(config_dict.viewkeys() - expected_dict.viewkeys())
return (False, "Invalid key(s) %r in the config" % invalid_list)

View File

@ -8,13 +8,23 @@ from sqlalchemy import Float, Enum, DateTime, ForeignKey, Text, Boolean
from sqlalchemy import UniqueConstraint
from sqlalchemy.orm import relationship, backref
from sqlalchemy.ext.declarative import declarative_base
from compass.utils import util
BASE = declarative_base()
class SwitchConfig(BASE):
__tablename__ = 'switch_config'
id = Column(Integer, primary_key=True)
ip = Column(String(80), ForeignKey("switch.ip"))
filter_port = Column(String(16))
__table_args__ = (UniqueConstraint('ip', 'filter_port', name='filter1'), )
def __init__(self, **kwargs):
super(SwitchConfig, self).__init__(**kwargs)
class Switch(BASE):
"""Switch table.
@ -26,6 +36,7 @@ class Switch(BASE):
:param state: Enum.'not_reached': polling switch fails or not complete to
learn all MAC addresses of devices connected to the switch;
'under_monitoring': successfully learn all MAC addresses.
:param err_msg: Error message when polling switch failed.
:param machines: refer to list of Machine connected to the switch.
"""
__tablename__ = 'switch'
@ -34,11 +45,12 @@ class Switch(BASE):
ip = Column(String(80), unique=True)
credential_data = Column(Text)
vendor_info = Column(String(256), nullable=True)
state = Column(Enum('not_reached', 'under_monitoring',
name='switch_state'))
state = Column(Enum('initialized', 'unreachable', 'notsupported',
'repolling', 'under_monitoring', name='switch_state'))
err_msg = Column(Text)
def __init__(self, **kwargs):
self.state = 'not_reached'
self.state = 'initialized'
super(Switch, self).__init__(**kwargs)
def __repr__(self):
@ -64,8 +76,6 @@ class Switch(BASE):
if self.credential_data:
try:
credential = json.loads(self.credential_data)
credential = dict(
[(str(k).title(), str(v)) for k, v in credential.items()])
return credential
except Exception as error:
logging.error('failed to load credential data %s: %s',
@ -115,14 +125,16 @@ class Machine(BASE):
__tablename__ = 'machine'
id = Column(Integer, primary_key=True)
mac = Column(String(24), unique=True)
port = Column(Integer)
mac = Column(String(24))
port = Column(String(16))
vlan = Column(Integer)
update_timestamp = Column(DateTime, default=datetime.now,
onupdate=datetime.now)
switch_id = Column(Integer, ForeignKey('switch.id',
onupdate='CASCADE',
ondelete='SET NULL'))
__table_args__ = (UniqueConstraint('mac', 'vlan', 'switch_id',
name='unique_1'), )
switch = relationship('Switch', backref=backref('machines',
lazy='dynamic'))
@ -361,7 +373,7 @@ class Cluster(BASE):
util.merge_dict(config, {'networking': self.networking})
util.merge_dict(config, {'partition': self.partition})
util.merge_dict(config, {'clusterid': self.id,
'clustername': self.name})
'clustername': self.name})
return config
@config.setter
@ -405,7 +417,7 @@ class ClusterHost(BASE):
machine_id = Column(Integer, ForeignKey('machine.id',
onupdate='CASCADE',
ondelete='CASCADE'),
nullable=True, unique=True)
nullable=True)
cluster_id = Column(Integer, ForeignKey('cluster.id',
onupdate='CASCADE',
@ -413,10 +425,10 @@ class ClusterHost(BASE):
nullable=True)
hostname = Column(String)
UniqueConstraint('cluster_id', 'hostname', name='unique_1')
config_data = Column(Text)
mutable = Column(Boolean, default=True)
__table_args__ = (UniqueConstraint('cluster_id', 'hostname',
name='unique_host'),)
cluster = relationship("Cluster", backref=backref('hosts', lazy='dynamic'))
machine = relationship("Machine", backref=backref('host', uselist=False))

View File

@ -1,7 +1,12 @@
"""
Base class extended by specific vendor in vendors directory.
a vendor need to implment abstract methods of base class.
A vendor needs to implement abstract methods of base class.
"""
import re
import logging
from compass.hdsdiscovery import utils
from compass.hdsdiscovery.error import TimeoutError
class BaseVendor(object):
@ -14,6 +19,23 @@ class BaseVendor(object):
raise NotImplementedError
class BaseSnmpVendor(BaseVendor):
"""Base SNMP-based vendor plugin. It uses MIB-II sysDescr value
to determine the vendor of the switch. """
def __init__(self, matched_names):
self._matched_names = matched_names
def is_this_vendor(self, host, credential, sys_info):
if utils.is_valid_snmp_v2_credential(credential) and sys_info:
for name in self._matched_names:
if re.search(r"\b" + re.escape(name) + r"\b", sys_info,
re.IGNORECASE):
return True
return False
class BasePlugin(object):
"""Extended by vendor's plugin, which processes request and
retrieve info directly from the switch.
@ -40,3 +62,89 @@ class BasePlugin(object):
def get(self, *args, **kwargs):
"""Get one record from a host"""
pass
class BaseSnmpMacPlugin(BasePlugin):
def __init__(self, host, credential, oid='BRIDGE-MIB::dot1dTpFdbPort',
vlan_oid='Q-BRIDGE-MIB::dot1qPvid'):
self.host = host
self.credential = credential
self.oid = oid
self.port_oid = 'ifName'
self.vlan_oid = vlan_oid
def process_data(self, oper='SCAN'):
func_name = oper.lower()
return getattr(self, func_name)()
def scan(self):
results = None
try:
results = utils.snmpwalk_by_cl(self.host, self.credential,
self.oid)
except TimeoutError as e:
logging.debug("PluginMac:scan snmpwalk_by_cl failed: %s",
e.message)
return None
mac_list = []
for entity in results:
ifIndex = entity['value']
if entity and int(ifIndex):
tmp = {}
mac_numbers = entity['iid'].split('.')
tmp['mac'] = self.get_mac_address(mac_numbers)
tmp['port'] = self.get_port(ifIndex)
tmp['vlan'] = self.get_vlan_id(ifIndex)
mac_list.append(tmp)
return mac_list
def get_vlan_id(self, port):
"""Get vlan Id"""
if not port:
return None
oid = '.'.join((self.vlan_oid, port))
vlan_id = None
result = None
try:
result = utils.snmpget_by_cl(self.host, self.credential, oid)
except TimeoutError as e:
logging.debug("[PluginMac:get_vlan_id snmpget_by_cl failed: %s]",
e.message)
return None
vlan_id = result.split()[-1]
return vlan_id
def get_port(self, if_index):
"""Get port number"""
if_name = '.'.join((self.port_oid, if_index))
result = None
try:
result = utils.snmpget_by_cl(self.host, self.credential, if_name)
except TimeoutError as e:
logging.debug("[PluginMac:get_port snmpget_by_cl failed: %s]",
e.message)
return None
# A result may be like "Value: FasterEthernet1/2/34
port = result.split()[-1].split('/')[-1]
return port
def convert_to_hex(self, value):
"""Convert the integer from decimal to hex"""
return "%0.2x" % int(value)
def get_mac_address(self, mac_numbers):
"""Assemble mac address from the list"""
if len(mac_numbers) != 6:
logging.error("[PluginMac:get_mac_address] MAC address must be "
"6 digitals")
return None
mac_in_hex = [self.convert_to_hex(num) for num in mac_numbers]
return ":".join(mac_in_hex)

View File

@ -0,0 +1,9 @@
"""hdsdiscovery module errors
"""
class TimeoutError(Exception):
def __init__(self, message):
self.message = message
def __str__(self):
return repr(self.message)

View File

@ -4,6 +4,10 @@ import re
import logging
from compass.hdsdiscovery import utils
from compass.hdsdiscovery.error import TimeoutError
UNREACHABLE = 'unreachable'
NOTSUPPORTED = 'notsupported'
class HDManager:
@ -13,6 +17,7 @@ class HDManager:
base_dir = os.path.dirname(os.path.realpath(__file__))
self.vendors_dir = os.path.join(base_dir, 'vendors')
self.vendor_plugins_dir = os.path.join(self.vendors_dir, '?/plugins')
self.snmp_sysdescr = 'sysDescr.0'
def learn(self, host, credential, vendor, req_obj, oper="SCAN", **kwargs):
"""Insert/update record of switch_info. Get expected results from
@ -50,38 +55,87 @@ class HDManager:
logging.error('no such directory: %s', vendor_dir)
return False
vendor_instance = utils.load_module(vendor, vendor_dir)
#TODO add more code to catch excpetion or unexpected state
if not vendor_instance:
# Cannot found the vendor in the directory!
logging.error('no vendor instance %s load from %s',
vendor, vendor_dir)
sys_info, err = self.get_sys_info(host, credential)
if not sys_info:
logging.debug("[hdsdiscovery][hdmanager][is_valid_vendor]"
"failded to get sys information: %s", err)
return False
return vendor_instance.is_this_vendor(host, credential)
instance = utils.load_module(vendor, vendor_dir)
if not instance:
logging.debug("[hdsdiscovery][hdmanager][is_valid_vendor]"
"No such vendor found!")
return False
if instance.is_this_vendor(host, credential, sys_info):
logging.info("[hdsdiscovery][hdmanager][is_valid_vendor]"
"vendor %s is correct!", vendor)
return True
return False
def get_vendor(self, host, credential):
""" Check and get vendor of the switch.
:param host: switch ip:
:param credential: credential to access switch
:return a tuple (vendor, switch_state, error)
"""
# TODO(grace): Why do we need to have valid IP?
# a hostname should also work.
if not utils.valid_ip_format(host):
logging.error("host '%s' is not valid IP address!" % host)
return (None, "", "Invalid IP address %s!" % host)
if not utils.is_valid_snmp_v2_credential(credential):
logging.debug("******The credential %s of host %s cannot "
"be used for either SNMP v2 or SSH*****",
credential, host)
return (None, "", "Invalid credential")
sys_info, err = self.get_sys_info(host, credential)
if not sys_info:
return (None, UNREACHABLE, err)
# List all vendors in vendors directory -- a directory but hidden
# under ../vendors
all_vendors = sorted(o for o in os.listdir(self.vendors_dir)
all_vendors = [o for o in os.listdir(self.vendors_dir)
if os.path.isdir(os.path.join(self.vendors_dir, o))
and re.match(r'^[^\.]', o))
and re.match(r'^[^\.]', o)]
logging.debug("[get_vendor]: %s ", all_vendors)
logging.debug("[get_vendor][available vendors]: %s ", all_vendors)
logging.debug("[get_vendor] System Information is [%s]" % sys_info)
# TODO(grace): should not conver to lower. The vendor impl can choose
# to do case-insensitive match
# sys_info = sys_info.lower()
vendor = None
for vname in all_vendors:
vpath = os.path.join(self.vendors_dir, vname)
instance = utils.load_module(vname, vpath)
#TODO add more code to catch excpetion or unexpected state
if not instance:
logging.error('no instance %s load from %s', vname, vpath)
continue
if instance.is_this_vendor(host, credential):
return vname
if instance.is_this_vendor(host, credential, sys_info):
logging.info("[get_vendor]****Found vendor '%s'****" % vname)
vendor = vname
break
return None
if not vendor:
logging.debug("[get_vendor] No vendor found! <==================")
return (None, NOTSUPPORTED, "Not supported switch vendor!")
return (vendor, "Found", "")
def get_sys_info(self, host, credential):
sys_info = None
try:
sys_info = utils.snmpget_by_cl(host,
credential,
self.snmp_sysdescr)
except TimeoutError as e:
return (None, e.message)
return (sys_info, "")

View File

@ -4,6 +4,9 @@
import imp
import re
import logging
import subprocess
from compass.hdsdiscovery.error import TimeoutError
def load_module(mod_name, path, host=None, credential=None):
@ -42,11 +45,17 @@ def ssh_remote_execute(host, username, password, cmd, *args):
"""
try:
import paramiko
if not cmd:
logging.error("[hdsdiscovery][utils][ssh_remote_execute] command"
"is None! Failed!")
return None
client = paramiko.SSHClient()
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
client.connect(host, username=username, password=password)
client.connect(host, username=username, password=password, timeout=15)
stdin, stdout, stderr = client.exec_command(cmd)
return stdout.readlines()
result = stdout.readlines()
return result
except ImportError as exc:
logging.error("[hdsdiscovery][utils][ssh_remote_execute] failed to"
@ -61,6 +70,9 @@ def ssh_remote_execute(host, username, password, cmd, *args):
return None
finally:
stdin.close()
stdout.close()
stderr.close()
client.close()
@ -79,17 +91,18 @@ def valid_ip_format(ip_address):
# Implement snmpwalk and snmpget funtionality
# The structure of returned dictionary will by tag/iid/value/type
#################################################################
AUTH_VERSIONS = {'v1': 1,
'v2c': 2,
'v3': 3}
AUTH_VERSIONS = {'1': 1,
'2c': 2,
'3': 3}
def snmp_walk(host, credential, *args):
def snmp_walk(host, credential, *args, **kwargs):
"""Impelmentation of snmpwalk functionality
:param host: switch ip
:param credential: credential to access switch
:param args: OIDs
:param kwargs: key-value pairs
"""
try:
import netsnmp
@ -98,14 +111,14 @@ def snmp_walk(host, credential, *args):
logging.error("Module 'netsnmp' do not exist! Please install it first")
return None
if 'Version' not in credential or 'Community' not in credential:
logging.error("[utils] missing 'Version' and 'Community' in %s",
if 'version' not in credential or 'community' not in credential:
logging.error("[utils] missing 'version' and 'community' in %s",
credential)
return None
if credential['Version'] in AUTH_VERSIONS:
version = AUTH_VERSIONS[credential['Version']]
credential['Version'] = version
version = None
if credential['version'] in AUTH_VERSIONS:
version = AUTH_VERSIONS[credential['version']]
varbind_list = []
for arg in args:
@ -114,9 +127,17 @@ def snmp_walk(host, credential, *args):
var_list = netsnmp.VarList(*varbind_list)
res = netsnmp.snmpwalk(var_list, DestHost=host, **credential)
netsnmp.snmpwalk(var_list,
DestHost=host,
Version=version,
Community=credential['community'],
**kwargs)
result = []
if not var_list:
logging.error("[hsdiscovery][utils][snmp_walk] retrived no record!")
return result
for var in var_list:
response = {}
response['elem_name'] = var.tag
@ -128,7 +149,7 @@ def snmp_walk(host, credential, *args):
return result
def snmp_get(host, credential, object_type):
def snmp_get(host, credential, object_type, **kwargs):
"""Impelmentation of snmp get functionality
:param object_type: mib object
@ -142,19 +163,100 @@ def snmp_get(host, credential, object_type):
logging.error("Module 'netsnmp' do not exist! Please install it first")
return None
if 'Version' not in credential or 'Community' not in credential:
if 'version' not in credential or 'community' not in credential:
logging.error('[uitls][snmp_get] missing keywords in %s for %s',
credential, host)
return None
if credential['Version'] in AUTH_VERSIONS:
version = AUTH_VERSIONS[credential['Version']]
credential['Version'] = version
version = None
if credential['version'] in AUTH_VERSIONS:
version = AUTH_VERSIONS[credential['version']]
varbind = netsnmp.Varbind(object_type)
res = netsnmp.snmpget(varbind, DestHost=host, **credential)
if not res:
logging.error('no result found for %s %s', host, credential)
res = netsnmp.snmpget(varbind,
DestHost=host,
Version=version,
Community=credential['community'],
**kwargs)
if res and res[0]:
return res[0]
logging.info('no result found for %s %s', host, credential)
return None
SSH_CREDENTIALS = {"username": "", "password": ""}
SNMP_V2_CREDENTIALS = {"version": "", "community": ""}
def is_valid_snmp_v2_credential(credential):
if credential.keys() != SNMP_V2_CREDENTIALS.keys():
return False
if credential['version'] != '2c':
logging.error("The value of version in credential is not '2c'!")
return False
return True
def is_valid_ssh_credential(credential):
if credential.keys() != SSH_CREDENTIALS.keys():
return False
return True
def snmpget_by_cl(host, credential, oid, timeout=8, retries=3):
if not is_valid_snmp_v2_credential(credential):
logging.error("[utils][snmpget_by_cl] Credential %s cannot be used "
"for SNMP request!" % credential)
return None
return res[0]
version = credential['version']
community = credential['community']
cl = ("snmpget -v %s -c %s -Ob -r %s -t %s %s %s"
% (version, community, retries, timeout, host, oid))
output = None
sub_p = subprocess.Popen(cl, shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
output, err = sub_p.communicate()
if err:
logging.error("[snmpget_by_cl] %s", err)
raise TimeoutError(err.strip('\n'))
return output.strip('\n')
def snmpwalk_by_cl(host, credential, oid, timeout=5, retries=3):
if not is_valid_snmp_v2_credential(credential):
logging.error("[utils][snmpwalk_by_cl] Credential %s cannot be used "
"for SNMP request!" % credential)
return None
version = credential['version']
community = credential['community']
cl = ("snmpwalk -v %s -c %s -Cc -r %s -t %s -Ob %s %s"
% (version, community, retries, timeout, host, oid))
output = []
sub_p = subprocess.Popen(cl, shell=True, stdout=subprocess.PIPE)
output, err = sub_p.communicate()
if err:
logging.debug("[snmpwalk_by_cl] %s ", err)
raise TimeoutError(err)
result = []
if not output:
return result
output = output.split('\n')
for line in output:
if not line:
continue
temp = {}
arr = line.split(" ")
temp['iid'] = arr[0].split('.', 1)[-1]
temp['value'] = arr[-1]
result.append(temp)
return result

View File

@ -1,51 +1,19 @@
"""Vendor: HP"""
import re
import logging
from compass.hdsdiscovery import base
from compass.hdsdiscovery import utils
#Vendor_loader will load vendor instance by CLASS_NAME
CLASS_NAME = 'Hp'
class Hp(base.BaseVendor):
class Hp(base.BaseSnmpVendor):
"""Hp switch object"""
def __init__(self):
# the name of switch model belonging to Hewlett-Packard (HP) vendor
base.BaseSnmpVendor.__init__(self, ['hp', 'procurve'])
self.names = ['hp', 'procurve']
def is_this_vendor(self, host, credential):
"""
Determine if the hostname is accociated witH this vendor.
This example will use snmp sysDescr OID ,regex to extract
the vendor's name ,and then compare with self.name variable.
:param host: switch's IP address
:param credential: credential to access switch
"""
if "Version" not in credential or "Community" not in credential:
# The format of credential is incompatible with this vendor
err_msg = "[Hp]Missing keyword 'Version' or 'Community' in %r"
logging.error(err_msg, credential)
return False
sys_info = utils.snmp_get(host, credential, "sysDescr.0")
if not sys_info:
logging.info("Dismatch vendor information")
return False
sys_info = sys_info.lower()
for name in self.names:
if re.search(r"\b" + re.escape(name) + r"\b", sys_info):
return True
return False
@property
def name(self):
"""Get 'name' proptery"""
return 'hp'
return self.names[0]

View File

@ -1,79 +1,9 @@
"""HP Switch Mac module"""
from compass.hdsdiscovery import utils
from compass.hdsdiscovery import base
from compass.hdsdiscovery.base import BaseSnmpMacPlugin
CLASS_NAME = 'Mac'
class Mac(base.BasePlugin):
class Mac(BaseSnmpMacPlugin):
"""Process MAC address by HP switch"""
def __init__(self, host, credential):
self.host = host
self.credential = credential
def process_data(self, oper='SCAN'):
"""Dynamically call the function according 'oper'
:param oper: operation of data processing
"""
func_name = oper.lower()
return getattr(self, func_name)()
def scan(self):
"""
Implemnets the scan method in BasePlugin class. In this mac module,
mac addesses were retrieved by snmpwalk python lib.
"""
walk_result = utils.snmp_walk(self.host, self.credential,
"BRIDGE-MIB::dot1dTpFdbPort")
if not walk_result:
return None
mac_list = []
for result in walk_result:
if not result or result['value'] == str(0):
continue
temp = {}
mac_numbers = result['iid'].split('.')
temp['mac'] = self._get_mac_address(mac_numbers)
temp['port'] = self._get_port(result['value'])
temp['vlan'] = self._get_vlan_id(temp['port'])
mac_list.append(temp)
return mac_list
def _get_vlan_id(self, port):
"""Get vlan Id"""
oid = '.'.join(('Q-BRIDGE-MIB::dot1qPvid', port))
vlan_id = utils.snmp_get(self.host, self.credential, oid).strip()
return vlan_id
def _get_port(self, if_index):
"""Get port number"""
if_name = '.'.join(('ifName', if_index))
port = utils.snmp_get(self.host, self.credential, if_name).strip()
return port
def _convert_to_hex(self, integer):
"""Convert the integer from decimal to hex"""
hex_string = str(hex(int(integer)))[2:]
length = len(hex_string)
if length == 1:
hex_string = str(0) + hex_string
return hex_string
def _get_mac_address(self, mac_numbers):
"""Assemble mac address from the list"""
mac = ""
for num in mac_numbers:
num = self._convert_to_hex(num)
mac = ':'.join((mac, num))
mac = mac[1:]
return mac
pass

View File

@ -1,50 +1,18 @@
"""Huawei Switch"""
import re
import logging
from compass.hdsdiscovery import base
from compass.hdsdiscovery import utils
#Vendor_loader will load vendor instance by CLASS_NAME
CLASS_NAME = "Huawei"
class Huawei(base.BaseVendor):
class Huawei(base.BaseSnmpVendor):
"""Huawei switch"""
def __init__(self):
base.BaseSnmpVendor.__init__(self, ["huawei"])
self.__name = "huawei"
def is_this_vendor(self, host, credential):
"""
Determine if the hostname is accociated witH this vendor.
This example will use snmp sysDescr OID ,regex to extract
the vendor's name ,and then compare with self.name variable.
:param host: swtich's IP address
:param credential: credential to access switch
"""
if not utils.valid_ip_format(host):
#invalid ip address
return False
if "Version" not in credential or "Community" not in credential:
# The format of credential is incompatible with this vendor
error_msg = "[huawei]Missing 'Version' or 'Community' in %r"
logging.error(error_msg, credential)
return False
sys_info = utils.snmp_get(host, credential, "sysDescr.0")
if not sys_info:
return False
if re.search(r"\b" + re.escape(self.__name) + r"\b", sys_info.lower()):
return True
return False
@property
def name(self):
"""Return switch name"""

View File

@ -1,111 +1,44 @@
import subprocess
import logging
from compass.hdsdiscovery import utils
from compass.hdsdiscovery import base
from compass.hdsdiscovery.base import BaseSnmpMacPlugin
CLASS_NAME = "Mac"
class Mac(base.BasePlugin):
class Mac(BaseSnmpMacPlugin):
"""Processes MAC address"""
def __init__(self, host, credential):
self.mac_mib_obj = 'HUAWEI-L2MAM-MIB::hwDynFdbPort'
self.host = host
self.credential = credential
def process_data(self, oper="SCAN"):
"""
Dynamically call the function according 'oper'
:param oper: operation of data processing
"""
func_name = oper.lower()
return getattr(self, func_name)()
super(Mac, self).__init__(host, credential,
'HUAWEI-L2MAM-MIB::hwDynFdbPort')
def scan(self):
"""
Implemnets the scan method in BasePlugin class. In this mac module,
mac addesses were retrieved by snmpwalk commandline.
"""
results = utils.snmpwalk_by_cl(self.host, self.credential, self.oid)
version = self.credential['Version']
community = self.credential['Community']
if version == 2:
# Command accepts 1|2c|3 as version arg
version = '2c'
cmd = 'snmpwalk -v%s -Cc -c %s -O b %s %s' % \
(version, community, self.host, self.mac_mib_obj)
try:
sub_p = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE)
result = []
for line in sub_p.stdout.readlines():
if not line or line == '\n':
continue
temp = {}
arr = line.split(" ")
temp['iid'] = arr[0].split('.', 1)[-1]
temp['value'] = arr[-1]
result.append(temp)
return self._process_mac(result)
except:
if not results:
logging.info("[Huawei][mac] No results returned from SNMP walk!")
return None
def _process_mac(self, walk_result):
"""Get mac addresses from snmpwalk result"""
mac_list = []
for entity in walk_result:
iid = entity['iid']
ifIndex = entity['value']
numbers = iid.split('.')
mac = self._get_mac_address(numbers, 6)
for entity in results:
# The format of 'iid' is like '248.192.1.214.34.15.31.1.48'
# The first 6 numbers will be the MAC address
# The 7th number is its vlan ID
numbers = entity['iid'].split('.')
mac = self.get_mac_address(numbers[:6])
vlan = numbers[6]
port = self._get_port(ifIndex)
port = self.get_port(entity['value'])
attri_dict_temp = {}
attri_dict_temp['port'] = port
attri_dict_temp['mac'] = mac
attri_dict_temp['vlan'] = vlan
mac_list.append(attri_dict_temp)
tmp = {}
tmp['port'] = port
tmp['mac'] = mac
tmp['vlan'] = vlan
mac_list.append(tmp)
return mac_list
def _get_port(self, if_index):
"""Get port number by using snmpget and OID 'IfName'
:param int if_index:the index of 'IfName'
"""
if_name = '.'.join(('ifName', if_index))
result = utils.snmp_get(self.host, self.credential, if_name)
"""result variable will be like: GigabitEthernet0/0/23"""
port = result.split("/")[2]
return port
def _convert_to_hex(self, integer):
"""Convert the integer from decimal to hex"""
hex_string = str(hex(int(integer)))[2:]
length = len(hex_string)
if length == 1:
hex_string = str(0) + hex_string
return hex_string
# Get Mac address: The first 6th numbers in the list
def _get_mac_address(self, iid_numbers, length):
"""Assemble mac address from the list"""
mac = ""
for index in range(length):
num = self._convert_to_hex(iid_numbers[index])
mac = ':'.join((mac, num))
mac = mac[1:]
return mac

View File

@ -15,20 +15,22 @@ class OVSwitch(base.BaseVendor):
def __init__(self):
self.__name = "Open vSwitch"
def is_this_vendor(self, host, credential):
def is_this_vendor(self, host, credential, sys_info):
"""Determine if the hostname is accociated witH this vendor.
:param host: swtich's IP address
:param credential: credential to access switch
"""
if "username" in credential and "password" in credential:
if utils.is_valid_ssh_credential(credential):
user = credential['username']
pwd = credential['password']
else:
logging.error('either username or password key is not in %s',
credential)
msg = ("[OVSwitch]The format of credential %r is not for SSH "
"or incorrect Keywords! " % credential)
logging.info(msg)
return False
cmd = "ovs-vsctl -V"
result = None
try:

View File

View File

@ -0,0 +1,19 @@
"""Vendor: Pica8"""
from compass.hdsdiscovery import base
#Vendor_loader will load vendor instance by CLASS_NAME
CLASS_NAME = 'Pica8'
class Pica8(base.BaseSnmpVendor):
"""Pica8 switch object"""
def __init__(self):
base.BaseSnmpVendor.__init__(self, ['pica8'])
self._name = 'pica8'
@property
def name(self):
"""Get 'name' proptery"""
return self._name

View File

View File

@ -0,0 +1,10 @@
"""Pica8 Switch Mac module"""
from compass.hdsdiscovery.base import BaseSnmpMacPlugin
CLASS_NAME = 'Mac'
class Mac(BaseSnmpMacPlugin):
"""Process MAC address by Pica8 switch"""
pass

View File

@ -1,4 +1,3 @@
import logging
import simplejson as json
from copy import deepcopy
from celery import current_app
@ -15,6 +14,7 @@ from compass.db.model import ClusterHost
from compass.db.model import HostState
from compass.db.model import Adapter
from compass.db.model import Role
from compass.db.model import SwitchConfig
class ApiTestCase(unittest2.TestCase):
@ -46,7 +46,7 @@ class ApiTestCase(unittest2.TestCase):
class TestSwtichMachineAPI(ApiTestCase):
SWITCH_RESP_TPL = {"state": "not_reached",
SWITCH_RESP_TPL = {"state": "under_monitoring",
"ip": "",
"link": {"href": "",
"rel": "self"},
@ -58,6 +58,7 @@ class TestSwtichMachineAPI(ApiTestCase):
with database.session() as session:
test_switch = Switch(ip=self.SWITCH_IP_ADDRESS1)
test_switch.credential = self.SWITCH_CREDENTIAL
test_switch.state = 'under_monitoring'
session.add(test_switch)
def tearDown(self):
@ -156,7 +157,6 @@ class TestSwtichMachineAPI(ApiTestCase):
# Non-exist switch id
url = '/switches/1000'
rv = self.app.get(url)
logging.info('[test_get_switch_by_id] url %s', url)
self.assertEqual(rv.status_code, 404)
correct_url = '/switches/1'
@ -187,6 +187,8 @@ class TestSwtichMachineAPI(ApiTestCase):
data = {'switch': {'credential': credential}}
rv = self.app.put(url, data=json.dumps(data))
self.assertEqual(rv.status_code, 202)
self.assertEqual(json.loads(rv.get_data())['switch']['state'],
'repolling')
def test_delete_switch(self):
url = '/switches/1'
@ -214,12 +216,32 @@ class TestSwtichMachineAPI(ApiTestCase):
def test_get_machineList(self):
#Prepare testing data
with database.session() as session:
switch_config = [
SwitchConfig(ip='10.10.10.1', filter_port='6'),
SwitchConfig(ip='10.10.10.1', filter_port='7')
]
session.add_all(switch_config)
machines = [Machine(mac='00:27:88:0c:01', port='1', vlan='1',
switch_id=1),
Machine(mac='00:27:88:0c:02', port='2', vlan='1',
switch_id=1),
Machine(mac='00:27:88:0c:03', port='3', vlan='1',
switch_id=1),
Machine(mac='00:27:88:01:04', port='4', vlan='1',
switch_id=1),
Machine(mac='00:27:88:01:05', port='5', vlan='1',
switch_id=1),
Machine(mac='00:27:88:01:06', port='6', vlan='1',
switch_id=1),
Machine(mac='00:27:88:01:07', port='7', vlan='1',
switch_id=1),
Machine(mac='00:27:88:01:08', port='8', vlan='1',
switch_id=1),
Machine(mac='00:27:88:01:09', port='9', vlan='1',
switch_id=1),
Machine(mac='00:27:88:01:10', port='10', vlan='1',
switch_id=1),
Machine(mac='00:27:88:0c:04', port='3', vlan='1',
switch_id=2),
Machine(mac='00:27:88:0c:05', port='4', vlan='2',
@ -228,13 +250,15 @@ class TestSwtichMachineAPI(ApiTestCase):
switch_id=3)]
session.add_all(machines)
testList = [{'url': '/machines', 'expected': 6},
testList = [{'url': '/machines', 'expected': 11},
{'url': '/machines?limit=3', 'expected': 3},
{'url': '/machines?limit=50', 'expected': 6},
{'url': '/machines?limit=50', 'expected': 11},
{'url': '/machines?switchId=1&vladId=1&port=2',
'expected': 1},
{'url': '/machines?switchId=1&vladId=1&limit=2',
'expected': 2},
{'url': '/machines?switchId=1', 'expected': 8},
{'url': '/machines?switchId=1&port=6', 'expected': 1},
{'url': '/machines?switchId=4', 'expected': 0}]
for test in testList:
@ -267,31 +291,33 @@ class TestClusterAPI(ApiTestCase):
"ip_end": "192.168.1.200",
"netmask": "255.255.255.0",
"gateway": "192.168.1.1",
"vlan": "",
"nic": "eth0",
"promisc": 1},
"tenant": {
"ip_start": "192.168.1.100",
"ip_end": "192.168.1.200",
"netmask": "255.255.255.0",
"gateway": "",
"nic": "eth1",
"promisc": 0},
"public": {
"ip_start": "192.168.1.100",
"ip_end": "192.168.1.200",
"netmask": "255.255.255.0",
"gateway": "",
"nic": "eth3",
"promisc": 1},
"storage": {
"ip_start": "192.168.1.100",
"ip_end": "192.168.1.200",
"netmask": "255.255.255.0",
"gateway": "",
"nic": "eth3",
"promisc": 1}},
"global": {
"gateway": "192.168.1.1",
"proxy": "",
"ntp_sever": "",
"ntp_server": "",
"nameservers": "8.8.8.8",
"search_path": "ods.com,ods1.com"}}
@ -357,7 +383,7 @@ class TestClusterAPI(ApiTestCase):
Cluster(name="cluster_04")]
session.add_all(clusters_list)
session.flush()
url = "/clusters"
rv = self.app.get(url)
data = json.loads(rv.get_data())
@ -371,6 +397,11 @@ class TestClusterAPI(ApiTestCase):
url = '/clusters/1/security'
rv = self.app.put(url, data=json.dumps(security))
self.assertEqual(rv.status_code, 200)
with database.session() as session:
cluster_security_config = session.query(Cluster.security_config)\
.filter_by(id=1).first()[0]
self.assertDictEqual(self.SECURITY_CONFIG,
json.loads(cluster_security_config))
# b. Update a non-existing cluster's resource
url = '/clusters/1000/security'
@ -383,26 +414,76 @@ class TestClusterAPI(ApiTestCase):
self.assertEqual(rv.status_code, 400)
# d. Security config is invalid -- some required field is null
security['security']['server_credentials']['username'] = None
rv = self.app.put(url, data=json.dumps(security))
url = "/clusters/1/security"
invalid_security = deepcopy(security)
invalid_security['security']['server_credentials']['username'] = None
rv = self.app.put(url, data=json.dumps(invalid_security))
self.assertEqual(rv.status_code, 400)
# e. Security config is invalid -- keyword is incorrect
security['security']['xxxx'] = {'xxx': 'xxx'}
rv = self.app.put(url, data=json.dumps(security))
invalid_security = deepcopy(security)
invalid_security['security']['xxxx'] = {'xxx': 'xxx'}
rv = self.app.put(url, data=json.dumps(invalid_security))
self.assertEqual(rv.status_code, 400)
# f. Security config is invalid -- missing keyword
invalid_security = deepcopy(security)
del invalid_security["security"]["server_credentials"]
rv = self.app.put(url, data=json.dumps(invalid_security))
self.assertEqual(rv.status_code, 400)
# g. Security config is invalid -- missing subkey keyword
invalid_security = deepcopy(security)
del invalid_security["security"]["server_credentials"]["username"]
rv = self.app.put(url, data=json.dumps(invalid_security))
self.assertEqual(rv.status_code, 400)
def test_put_cluster_networking_resource(self):
networking = {"networking" : self.NETWORKING_CONFIG}
networking = {"networking": self.NETWORKING_CONFIG}
url = "/clusters/1/networking"
rv = self.app.put(url, data=json.dumps(networking))
self.assertEqual(rv.status_code, 200)
# Missing some required keyword in interfaces section
invalid_config = deepcopy(networking)
del invalid_config["networking"]["interfaces"]["management"]["nic"]
rv = self.app.put(url, data=json.dumps(invalid_config))
self.assertEqual(rv.status_code, 400)
invalid_config = deepcopy(networking)
del invalid_config["networking"]["interfaces"]["management"]
rv = self.app.put(url, data=json.dumps(invalid_config))
self.assertEqual(rv.status_code, 400)
invalid_config = deepcopy(networking)
invalid_config["networking"]["interfaces"]["xxx"] = {}
rv = self.app.put(url, data=json.dumps(invalid_config))
self.assertEqual(rv.status_code, 400)
# Missing some required keyword in global section
invalid_config = deepcopy(networking)
del invalid_config["networking"]["global"]["gateway"]
rv = self.app.put(url, data=json.dumps(invalid_config))
self.assertEqual(rv.status_code, 400)
# Invalid value in interfaces section
invalid_config = deepcopy(networking)
invalid_config["networking"]["interfaces"]["tenant"]["nic"] = "eth0"
rv = self.app.put(url, data=json.dumps(invalid_config))
self.assertEqual(rv.status_code, 400)
# Invalid value in global section
invalid_config = deepcopy(networking)
invalid_config["networking"]["global"]["nameservers"] = "*.*.*.*,"
rv = self.app.put(url, data=json.dumps(invalid_config))
self.assertEqual(rv.status_code, 400)
def test_get_cluster_resource(self):
# Test only one resource - secuirty as an example
# Test resource
with database.session() as session:
cluster = session.query(Cluster).filter_by(id=1).first()
cluster.security = self.SECURITY_CONFIG
cluster.networking = self.NETWORKING_CONFIG
# a. query secuirty config by cluster id
url = '/clusters/1/security'
@ -411,6 +492,12 @@ class TestClusterAPI(ApiTestCase):
self.assertEqual(rv.status_code, 200)
self.assertDictEqual(data['security'], self.SECURITY_CONFIG)
url = '/clusters/1/networking'
rv = self.app.get(url)
data = json.loads(rv.get_data())
self.assertEqual(rv.status_code, 200)
self.assertDictEqual(data['networking'], self.NETWORKING_CONFIG)
# b. query a nonsupported resource, return 400
url = '/clusters/1/xxx'
rv = self.app.get(url)
@ -428,7 +515,11 @@ class TestClusterAPI(ApiTestCase):
machines = [Machine(mac='00:27:88:0c:01'),
Machine(mac='00:27:88:0c:02'),
Machine(mac='00:27:88:0c:03'),
Machine(mac='00:27:88:0c:04')]
Machine(mac='00:27:88:0c:04'),
Machine(mac='00:27:88:0c:05'),
Machine(mac='00:27:88:0c:06'),
Machine(mac='00:27:88:0c:07'),
Machine(mac='00:27:88:0c:08')]
clusters = [Cluster(name='cluster_02')]
session.add_all(machines)
session.add_all(clusters)
@ -466,32 +557,49 @@ class TestClusterAPI(ApiTestCase):
request = {'addHosts': [1, 2, 3]}
rv = self.app.post(url, data=json.dumps(request))
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.get_data())
self.assertEqual(len(data['cluster_hosts']), 3)
total_hosts = 0
with database.session() as session:
total_hosts = session.query(func.count(ClusterHost.id))\
.filter_by(cluster_id=1).scalar()
data = json.loads(rv.get_data())
self.assertEqual(len(data['cluster_hosts']), total_hosts)
self.assertEqual(total_hosts, 3)
# 4. try to remove some hosts which do not exists
request = {'removeHosts': [1, 1000, 1001]}
# 4. try to remove some hosts not existing and in different cluster
request = {'removeHosts': [1, 2, 3, 1000, 1001]}
rv = self.app.post(url, data=json.dumps(request))
self.assertEqual(rv.status_code, 404)
data = json.loads(rv.get_data())
self.assertEqual(len(data['failedHosts']), 2)
self.assertEqual(len(data['failedHosts']), 3)
with database.session() as session:
count = session.query(func.count(ClusterHost.id))\
.filter_by(cluster_id=1).scalar()
self.assertEqual(count, 3)
# 5. sucessfully remove requested hosts
request = {'removeHosts': [1, 2]}
request = {'removeHosts': [2, 3]}
rv = self.app.post(url, data=json.dumps(request))
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.get_data())
self.assertEqual(len(data['cluster_hosts']), 2)
with database.session() as session:
count = session.query(func.count(ClusterHost.id))\
.filter_by(cluster_id=1).scalar()
self.assertEqual(count, 1)
# 6. Test 'replaceAllHosts' action on cluster_01
request = {'replaceAllHosts': [1, 2, 3]}
request = {'replaceAllHosts': [5, 6, 7]}
rv = self.app.post(url, data=json.dumps(request))
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.get_data())
self.assertEqual(len(data['cluster_hosts']), 3)
with database.session() as session:
count = session.query(func.count(ClusterHost.id))\
.filter_by(cluster_id=1).scalar()
self.assertEqual(count, 3)
# 7. Test 'deploy' action on cluster_01
request = {'deploy': {}}
request = {'deploy': []}
rv = self.app.post(url, data=json.dumps(request))
self.assertEqual(rv.status_code, 202)
@ -499,7 +607,7 @@ class TestClusterAPI(ApiTestCase):
rv = self.app.post(url, data=json.dumps(request))
self.assertEqual(rv.status_code, 400)
# 9. Try to deploy cluster_02 which no host
# 9. Try to deploy cluster_02 which no host in
url = '/clusters/2/action'
with database.session() as session:
session.query(ClusterHost).filter_by(cluster_id=2)\
@ -509,6 +617,36 @@ class TestClusterAPI(ApiTestCase):
rv = self.app.post(url, data=json.dumps(request))
self.assertEqual(rv.status_code, 404)
# 10. Try to add a new host to cluster_01 and deploy it
with database.session() as session:
cluster = session.query(Cluster).filter_by(id=1).first()
cluster.mutable = True
hosts = session.query(ClusterHost).filter_by(cluster_id=1).all()
for host in hosts:
host.mutable = True
url = '/clusters/1/action'
# add another machine as a new host into cluster_01
request = json.dumps({"addHosts": [8]})
rv = self.app.post(url, data=request)
host_id = json.loads(rv.get_data())["cluster_hosts"][0]["id"]
deploy_request = json.dumps({"deploy": [host_id]})
rv = self.app.post(url, data=deploy_request)
self.assertEqual(202, rv.status_code)
expected_deploy_result = {
"cluster": {
"cluster_id": 1,
"url": "/clusters/1/progress"
},
"hosts": [
{"host_id": 5,
"url": "/cluster_hosts/5/progress"}
]
}
data = json.loads(rv.get_data())["deployment"]
self.assertDictEqual(expected_deploy_result, data)
class ClusterHostAPITest(ApiTestCase):
@ -526,10 +664,18 @@ class ClusterHostAPITest(ApiTestCase):
clusters_list = [Cluster(name='cluster_01'),
Cluster(name='cluster_02')]
session.add_all(clusters_list)
hosts_list = [ClusterHost(hostname='host_02', cluster_id=1),
ClusterHost(hostname='host_03', cluster_id=1),
ClusterHost(hostname='host_04', cluster_id=2)]
host = ClusterHost(hostname='host_01', cluster_id=1)
machines_list = [Machine(mac='00:27:88:0c:01'),
Machine(mac='00:27:88:0c:02'),
Machine(mac='00:27:88:0c:03'),
Machine(mac='00:27:88:0c:04')]
session.add_all(machines_list)
hosts_list = [
ClusterHost(hostname='host_02', cluster_id=1, machine_id=2),
ClusterHost(hostname='host_03', cluster_id=1, machine_id=3),
ClusterHost(hostname='host_04', cluster_id=2, machine_id=4)
]
host = ClusterHost(hostname='host_01', cluster_id=1, machine_id=1)
host.config_data = json.dumps(self.test_config_data)
session.add(host)
session.add_all(hosts_list)
@ -556,6 +702,8 @@ class ClusterHostAPITest(ApiTestCase):
expected_config['hostname'] = 'host_01'
expected_config['clusterid'] = 1
expected_config['clustername'] = 'cluster_01'
expected_config['networking']['interfaces']['management']['mac'] \
= "00:27:88:0c:01"
self.assertDictEqual(config, expected_config)
def test_clusterHost_put_config(self):
@ -593,7 +741,6 @@ class ClusterHostAPITest(ApiTestCase):
url = 'clusterhosts/1/config/ip'
rv = self.app.delete(url)
self.assertEqual(200, rv.status_code)
expected_config = deepcopy(self.test_config_data)
expected_config['networking']['interfaces']['management']['ip'] = ''
with database.session() as session:
@ -742,13 +889,229 @@ class TestAdapterAPI(ApiTestCase):
"rel": "self"}
}
self.assertDictEqual(execpted_result, data['adapters'][0])
url = '/adapters'
rv = self.app.get(url)
data = json.loads(rv.get_data())
self.assertEqual(200, rv.status_code)
self.assertEqual(2, len(data['adapters']))
class TestAPIWorkFlow(ApiTestCase):
CLUSTER_SECURITY_CONFIG = {
"security": {
"server_credentials": {
"username": "admin",
"password": "admin"},
"service_credentials": {
"username": "admin",
"password": "admin"},
"console_credentials": {
"username": "admin",
"password": "admin"}
}
}
CLUSTER_NETWORKING_CONFIG = {
"networking": {
"interfaces": {
"management": {
"ip_start": "10.120.8.100",
"ip_end": "10.120.8.200",
"netmask": "255.255.255.0",
"gateway": "",
"nic": "eth0",
"promisc": 1
},
"tenant": {
"ip_start": "192.168.10.100",
"ip_end": "192.168.10.200",
"netmask": "255.255.255.0",
"gateway": "",
"nic": "eth1",
"promisc": 0
},
"public": {
"ip_start": "12.145.68.100",
"ip_end": "12.145.68.200",
"netmask": "255.255.255.0",
"gateway": "",
"nic": "eth2",
"promisc": 0
},
"storage": {
"ip_start": "172.29.8.100",
"ip_end": "172.29.8.200",
"netmask": "255.255.255.0",
"gateway": "",
"nic": "eth3",
"promisc": 0
}
},
"global": {
"nameservers": "8.8.8.8",
"search_path": "ods.com",
"gateway": "192.168.1.1",
"proxy": "http://127.0.0.1:3128",
"ntp_server": "127.0.0.1"
}
}
}
CLUSTER_PARTITION_CONFIG = {
"partition": "/home 20%;"
}
CLUSTERHOST_CONFIG = {
"hostname": "",
"networking": {
"interfaces": {
"management": {
"ip": ""
}
}
},
"roles": ["base"]
}
def setUp(self):
super(TestAPIWorkFlow, self).setUp()
#Prepare test data
with database.session() as session:
# Populate switch info to DB
switch = Switch(ip="192.168.2.1",
credential={"version": "2c",
"community": "public"},
vendor="huawei",
state="under_monitoring")
session.add(switch)
# Populate machines info to DB
machines = [
Machine(mac='00:27:88:0c:a6', port='1', vlan='1', switch_id=1),
Machine(mac='00:27:88:0c:a7', port='2', vlan='1', switch_id=1),
Machine(mac='00:27:88:0c:a8', port='3', vlan='1', switch_id=1),
]
session.add_all(machines)
def tearDown(self):
super(TestAPIWorkFlow, self).tearDown()
def test_work_flow(self):
# Polling switch: mock post switch
# url = '/switches'
# data = {"ip": "192.168.2.1",
# "credential": {"version": "2c", "community": "public"}}
# self.app.post(url, json.dumps(data))
# Get machines once polling switch done. If switch state changed to
# "under_monitoring" state.
url = '/switches/1'
switch_state = "initialized"
while switch_state != "under_monitoring":
rv = self.app.get(url)
switch_state = json.loads(rv.get_data())['switch']['state']
url = '/machines?switchId=1'
rv = self.app.get(url)
self.assertEqual(200, rv.status_code)
machines = json.loads(rv.get_data())['machines']
# Create a Cluster and get cluster id from response
url = '/clusters'
data = {
"cluster": {
"name": "cluster_01",
"adapter_id": 1
}
}
rv = self.app.post(url, data=json.dumps(data))
self.assertEqual(200, rv.status_code)
cluster_id = json.loads(rv.get_data())['cluster']['id']
# Add machines as hosts of the cluster
url = '/clusters/%s/action' % cluster_id
machines_id = []
for m in machines:
machines_id.append(m["id"])
data = {"addHosts": machines_id}
rv = self.app.post(url, data=json.dumps(data))
self.assertEqual(200, rv.status_code)
hosts_info = json.loads(rv.get_data())["cluster_hosts"]
# Update cluster security configuration
url = '/clusters/%s/security' % cluster_id
security_config = json.dumps(self.CLUSTER_SECURITY_CONFIG)
rv = self.app.put(url, data=security_config)
self.assertEqual(200, rv.status_code)
# Update cluster networking configuration
url = '/clusters/%s/networking' % cluster_id
networking_config = json.dumps(self.CLUSTER_NETWORKING_CONFIG)
rv = self.app.put(url, data=networking_config)
self.assertEqual(200, rv.status_code)
# Update cluster partition configuration
url = '/clusters/%s/partition' % cluster_id
partition_config = json.dumps(self.CLUSTER_PARTITION_CONFIG)
rv = self.app.put(url, data=partition_config)
self.assertEqual(200, rv.status_code)
# Put cluster host config individually
hosts_configs = [
deepcopy(self.CLUSTERHOST_CONFIG),
deepcopy(self.CLUSTERHOST_CONFIG),
deepcopy(self.CLUSTERHOST_CONFIG)
]
names = ["host_01", "host_02", "host_03"]
mgmt_ips = ["10.120.8.100", "10.120.8.101", "10.120.8.102"]
for config, name, ip in zip(hosts_configs, names, mgmt_ips):
config["hostname"] = name
config["networking"]["interfaces"]["management"]["ip"] = ip
for config, host_info in zip(hosts_configs, hosts_info):
host_id = host_info["id"]
url = 'clusterhosts/%d/config' % host_id
rv = self.app.put(url, data=json.dumps(config))
self.assertEqual(200, rv.status_code)
# deploy the Cluster
url = "/clusters/%d/action" % cluster_id
data = json.dumps({"deploy": []})
self.app.post(url, data=data)
self.assertEqual(200, rv.status_code)
# Verify the final cluster configuration
expected_cluster_config = {}
expected_cluster_config.update(self.CLUSTER_SECURITY_CONFIG)
expected_cluster_config.update(self.CLUSTER_NETWORKING_CONFIG)
expected_cluster_config.update(self.CLUSTER_PARTITION_CONFIG)
expected_cluster_config["clusterid"] = cluster_id
expected_cluster_config["clustername"] = "cluster_01"
with database.session() as session:
cluster = session.query(Cluster).filter_by(id=cluster_id).first()
config = cluster.config
self.assertDictEqual(config, expected_cluster_config)
# Verify each host configuration
for host_info, excepted in zip(hosts_info, hosts_configs):
machine_id = host_info["machine_id"]
machine = session.query(Machine).filter_by(id=machine_id)\
.first()
mac = machine.mac
excepted["clusterid"] = cluster_id
excepted["clustername"] = "cluster_01"
excepted["hostid"] = host_info["id"]
excepted["networking"]["interfaces"]["management"]["mac"] = mac
host = session.query(ClusterHost)\
.filter_by(id=host_info["id"]).first()
self.maxDiff = None
self.assertDictEqual(host.config, excepted)
if __name__ == '__main__':
unittest2.main()

View File

@ -0,0 +1,92 @@
import unittest2
from mock import patch
from compass.hdsdiscovery.base import BaseSnmpVendor
from compass.hdsdiscovery.base import BaseSnmpMacPlugin
class MockSnmpVendor(BaseSnmpVendor):
def __init__(self):
BaseSnmpVendor.__init__(self, ["MockVendor", "FakeVendor"])
class TestBaseSnmpMacPlugin(unittest2.TestCase):
def setUp(self):
self.test_plugin = BaseSnmpMacPlugin('12.0.0.1',
{'version': '2c',
'community': 'public'})
def tearDown(self):
del self.test_plugin
@patch('compass.hdsdiscovery.utils.snmpget_by_cl')
def test_get_port(self, mock_snmpget):
mock_snmpget.return_value = 'IF-MIB::ifName.4 = STRING: ge-1/1/4'
result = self.test_plugin.get_port('4')
self.assertEqual('4', result)
@patch('compass.hdsdiscovery.utils.snmpget_by_cl')
def test_get_vlan_id(self, mock_snmpget):
# Port is None
self.assertIsNone(self.test_plugin.get_vlan_id(None))
# Port is not None
mock_snmpget.return_value = 'Q-BRIDGE-MIB::dot1qPvid.4 = Gauge32: 100'
result = self.test_plugin.get_vlan_id('4')
self.assertEqual('100', result)
def test_get_mac_address(self):
# Correct input for mac numbers
mac_numbers = '0.224.129.230.57.173'.split('.')
mac = self.test_plugin.get_mac_address(mac_numbers)
self.assertEqual('00:e0:81:e6:39:ad', mac)
# Incorrct input for mac numbers
mac_numbers = '0.224.129.230.57'.split('.')
mac = self.test_plugin.get_mac_address(mac_numbers)
self.assertIsNone(mac)
class BaseTest(unittest2.TestCase):
def setUp(self):
pass
def tearDown(self):
pass
def test_base_snmp_vendor(self):
fake = MockSnmpVendor()
credential = {"version": "2c",
"community": "public"}
is_vendor = fake.is_this_vendor("12.0.0.1", credential,
"FakeVendor 1.1")
self.assertTrue(is_vendor)
# check case-insensitive match
self.assertFalse(fake.is_this_vendor("12.0.0.1", credential,
"fakevendor1.1"))
# breaks word-boudary match
self.assertFalse(fake.is_this_vendor("12.0.0.1", credential,
"FakeVendor1.1"))
# Not SNMP credentials
self.assertFalse(fake.is_this_vendor("12.0.0.1",
{"username": "root",
"password": "test123"},
"fakevendor1.1"))
# Not SNMP v2 credentials
self.assertFalse(fake.is_this_vendor("12.0.0.1",
{"version": "v1",
"community": "public"},
"fakevendor1.1"))
if __name__ == '__main__':
unittest2.main()

View File

@ -1,121 +1,61 @@
import unittest2
from mock import patch
from compass.hdsdiscovery.hdmanager import HDManager
from compass.hdsdiscovery.vendors.huawei.huawei import Huawei
from compass.hdsdiscovery.vendors.huawei.plugins.mac import Mac
class HuaweiTest(unittest2.TestCase):
def setUp(self):
self.huawei = Huawei()
self.correct_host = '172.29.8.40'
self.correct_credentials = {'Version': 'v2c', 'Community': 'public'}
self.correct_host = '12.23.1.1'
self.correct_credentials = {'version': '2c', 'community': 'public'}
self.sys_info = 'Huawei Technologies'
def tearDown(self):
del self.huawei
@patch('compass.hdsdiscovery.utils.snmp_get')
def test_IsThisVendor_WithIncorrectIPFormat(self, snmp_get_mock):
snmp_get_mock.return_value = None
#host is incorrest IP address format
self.assertFalse(self.huawei.is_this_vendor('500.10.1.2000',
self.correct_credentials))
@patch('compass.hdsdiscovery.utils.snmp_get')
def test_IsThisVendor_WithWrongCredential(self, snmp_get_mock):
snmp_get_mock.return_value = None
def test_is_this_vendor(self):
#Credential's keyword is incorrect
self.assertFalse(
self.huawei.is_this_vendor(self.correct_host,
{'username': 'root',
'Community': 'public'}))
'password': 'root'},
self.sys_info))
#Incorrect Version
#Incorrect version
self.assertFalse(
self.huawei.is_this_vendor(self.correct_host,
{'Version': 'v1',
'Community': 'public'}))
{'version': 'v1',
'community': 'public'},
self.sys_info))
#Incorrect Community
self.assertFalse(
#Correct vendor
self.assertTrue(
self.huawei.is_this_vendor(self.correct_host,
{'Version': 'v2c',
'Community': 'private'}))
@patch('compass.hdsdiscovery.utils.snmp_get')
def test_IsThisVendor_WithCorrectInput(self, snmp_get_mock):
snmp_get_mock.return_value = "Huawei"
self.assertTrue(self.huawei.is_this_vendor(self.correct_host,
self.correct_credentials))
@patch('compass.hdsdiscovery.utils.snmp_get')
def test_IsThisVendor_WithIncorrectVendor(self, snmp_get_mock):
snmp_get_mock.return_value = None
self.assertFalse(
self.huawei.is_this_vendor('1.1.1.1',
{'Version': 'v1',
'Community': 'private'}))
from compass.hdsdiscovery.vendors.huawei.plugins.mac import Mac
self.correct_credentials,
self.sys_info))
class HuaweiMacTest(unittest2.TestCase):
def setUp(self):
host = '172.29.8.40'
credential = {'Version': 'v2c', 'Community': 'public'}
self.mac = Mac(host, credential)
host = '12.23.1.1'
credential = {'version': '2c', 'community': 'public'}
self.mac_plugin = Mac(host, credential)
def tearDown(self):
del self.mac
del self.mac_plugin
def test_ProcessData_Operation(self):
# GET operation haven't been implemeneted.
self.assertIsNone(self.mac.process_data('GET'))
self.assertIsNone(self.mac_plugin.process_data('GET'))
from compass.hdsdiscovery.vendors.ovswitch.ovswitch import OVSwitch
from compass.hdsdiscovery.vendors.ovswitch.plugins.mac import Mac as OVSMac
class OVSTest(unittest2.TestCase):
def setUp(self):
self.host = '10.145.88.160'
self.credential = {'username': 'root', 'password': 'huawei'}
self.ovswitch = OVSwitch()
def tearDown(self):
del self.ovswitch
@patch('compass.hdsdiscovery.utils.ssh_remote_execute')
def test_isThisVendor_withIncorrectInput(self, ovs_mock):
ovs_mock.return_value = []
# Incorrect host ip
self.assertFalse(self.ovswitch.is_this_vendor('1.1.1.1',
self.credential))
# Incorrect credential
self.assertFalse(
self.ovswitch.is_this_vendor(self.host,
{'username': 'xxx',
'password': 'xxx'}))
# not Open vSwitch
self.assertFalse(
self.ovswitch.is_this_vendor(self.host,
{'Version': 'xxx',
'Community': 'xxx'}))
# not Open vSwitch, snmpv3
self.assertFalse(
self.ovswitch.is_this_vendor(self.host,
{'Version': 'xxx',
'Community': 'xxx',
'username': 'xxx',
'password': 'xxx'}))
class OVSMacTest(unittest2.TestCase):
def setUp(self):
self.host = '10.145.88.160'
@ -134,74 +74,81 @@ class OVSMacTest(unittest2.TestCase):
del mac_instance
from compass.hdsdiscovery.vendors.hp.hp import Hp
class HpTest(unittest2.TestCase):
def setUp(self):
self.host = '10.145.88.140'
self.credential = {'Version': 'v2c', 'Community': 'public'}
self.hpSwitch = Hp()
def tearDown(self):
del self.hpSwitch
@patch('compass.hdsdiscovery.utils.snmp_get')
def test_IsThisVendor(self, snmpget_mock):
snmpget_mock.return_value = "ProCurve J9089A Switch 2610-48-PWR"
self.assertTrue(self.hpSwitch.is_this_vendor(self.host,
self.credential))
snmpget_mock.return_value = None
self.assertFalse(self.hpSwitch.is_this_vendor(self.host,
self.credential))
snmpget_mock.return_value = "xxxxxxxxxxx"
self.assertFalse(self.hpSwitch.is_this_vendor(self.host,
self.credential))
from compass.hdsdiscovery.hdmanager import HDManager
class HDManagerTest(unittest2.TestCase):
def setUp(self):
self.manager = HDManager()
self.correct_host = '172.29.8.40'
self.correct_credential = {'Version': 'v2c', 'Community': 'public'}
self.ovs_host = '10.145.88.160'
self.ovs_credential = {'username': 'root', 'password': 'huawei'}
self.correct_host = '12.23.1.1'
self.correct_credential = {'version': '2c', 'community': 'public'}
def tearDown(self):
del self.manager
@patch('compass.hdsdiscovery.utils.ssh_remote_execute')
@patch('compass.hdsdiscovery.utils.snmp_get')
def test_GetVendor_WithIncorrectInput(self, snmp_get_mock, ovs_mock):
snmp_get_mock.return_value = None
ovs_mock.return_value = []
@patch('compass.hdsdiscovery.hdmanager.HDManager.get_sys_info')
def test_get_vendor(self, sys_info_mock):
# Incorrect ip
self.assertIsNone(self.manager.get_vendor('1.1.1.1',
self.correct_credential))
self.assertIsNone(self.manager.get_vendor('1.1.1.1',
self.ovs_credential))
self.assertIsNone(self.manager.get_vendor('1234.1.1.1',
self.correct_credential)[0])
# Incorrect credential
self.assertIsNone(
self.manager.get_vendor(self.correct_host,
{'Version': '1v', 'Community': 'private'}))
self.assertIsNone(
self.manager.get_vendor(self.ovs_host,
{'username': 'xxxxx', 'password': 'xxxx'}))
{'version': '1v',
'community': 'private'})[0])
# SNMP get system description Timeout
excepted_err_msg = 'Timeout: No Response from 12.23.1.1.'
sys_info_mock.return_value = (None, excepted_err_msg)
result, state, err = self.manager.get_vendor(self.correct_host,
self.correct_credential)
self.assertIsNone(result)
self.assertEqual(state, 'unreachable')
self.assertEqual(err, excepted_err_msg)
# No vendor plugin supported
excepted_err_msg = 'Not supported switch vendor!'
sys_info_mock.return_value = ('xxxxxx', excepted_err_msg)
result, state, err = self.manager.get_vendor(self.correct_host,
self.correct_credential)
self.assertIsNone(result)
self.assertEqual(state, 'notsupported')
self.assertEqual(err, excepted_err_msg)
# Found the correct vendor
sys_info = ['Huawei Versatile Routing Platform Software',
'ProCurve J9089A Switch 2610-48-PWR, revision R.11.25',
'Pica8 XorPlus Platform Software']
expected_vendor_names = ['huawei', 'hp', 'pica8']
for info, expected_vendor in zip(sys_info, expected_vendor_names):
sys_info_mock.return_value = (info, '')
result, state, err = self.manager\
.get_vendor(self.correct_host,
self.correct_credential)
self.assertEqual(result, expected_vendor)
@patch('compass.hdsdiscovery.hdmanager.HDManager.get_sys_info')
def test_is_valid_vendor(self, sys_info_mock):
def test_ValidVendor(self):
#non-exsiting vendor
self.assertFalse(self.manager.is_valid_vendor(self.correct_host,
self.correct_credential,
'xxxx'))
#No system description retrieved
sys_info_mock.return_value = (None, 'TIMEOUT')
self.assertFalse(self.manager.is_valid_vendor(self.correct_host,
self.correct_credential,
'pica8'))
#Incorrect vendor name
sys_info = 'Pica8 XorPlus Platform Software'
sys_info_mock.return_value = (sys_info, '')
self.assertFalse(self.manager.is_valid_vendor(self.correct_host,
self.correct_credential,
'huawei'))
#Correct vendor name
self.assertTrue(self.manager.is_valid_vendor(self.correct_host,
self.correct_credential,
'pica8'))
def test_Learn(self):
#non-exsiting plugin