Add options supporting DataSource identifiers in job_configs
This change adds options that allow DataSource objects to be referenced by
name or uuid in the job_configs dictionary of a job_execution. If a reference
to a DataSource is found, the path information replaces the reference.

Note: references are partially resolved in early processing to determine
whether or not a proxy user must be created. References are fully resolved
in run_job().

Implements: blueprint edp-data-sources-in-job-configs
Change-Id: I5be62b798b86a8aaf933c2cc6b6d5a252f0a8627
This commit is contained in:
parent c7dc7968db
commit 8750ddc121
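For illustration, a job_configs dictionary that opts into the new substitution behavior might look like the sketch below (the data source name and uuid are made-up placeholders; the config keys and the datasource:// prefix are the ones defined in job_utils.py in this change):

    job_configs = {
        'configs': {
            'edp.substitute_data_source_for_name': True,
            'edp.substitute_data_source_for_uuid': True,
        },
        # With substitution enabled, run_job() replaces these references
        # with the urls of the matching DataSource objects.
        'args': ['datasource://nightly-logs',
                 '5c1f3f11-8a52-4e6e-9f5c-3f2d1a9c0b47']
    }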
@@ -235,6 +235,14 @@ class LocalApi(object):
        """
        return self._manager.data_source_get_all(context, **kwargs)

    def data_source_count(self, context, **kwargs):
        """Count Data Sources filtered by **kwargs.

        Uses sqlalchemy "in_" clause for any tuple values
        Uses sqlalchemy "like" clause for any string values containing %
        """
        return self._manager.data_source_count(context, **kwargs)

    @r.wrap(r.DataSource)
    def data_source_create(self, context, values):
        """Create a Data Source from the values dictionary."""
@@ -289,6 +289,14 @@ class ConductorManager(db_base.Base):
        """
        return self.db.data_source_get_all(context, **kwargs)

    def data_source_count(self, context, **kwargs):
        """Count Data Sources filtered by **kwargs.

        Uses sqlalchemy "in_" clause for any tuple values
        Uses sqlalchemy "like" clause for any string values containing %
        """
        return self.db.data_source_count(context, **kwargs)

    def data_source_create(self, context, values):
        """Create a Data Source from the values dictionary."""
        values = copy.deepcopy(values)
@@ -262,6 +262,15 @@ def data_source_get_all(context, **kwargs):
    return IMPL.data_source_get_all(context, **kwargs)


def data_source_count(context, **kwargs):
    """Count Data Sources filtered by **kwargs.

    Uses sqlalchemy "in_" clause for any tuple values
    Uses sqlalchemy "like" clause for any string values containing %
    """
    return IMPL.data_source_count(context, **kwargs)


@to_dict
def data_source_create(context, values):
    """Create a Data Source from the values dictionary."""
@@ -99,6 +99,76 @@ def count_query(model, context, session=None, project_only=None):
    return model_query(sa.func.count(model.id), context, session, project_only)


def in_filter(query, cls, search_opts):
    """Add 'in' filters for specified columns.

    Add a sqlalchemy 'in' filter to the query for any entry in the
    'search_opts' dict where the key is the name of a column in
    'cls' and the value is a tuple.

    This allows the value of a column to be matched
    against multiple possible values (OR).

    Return the modified query and any entries in search_opts
    whose keys do not match columns or whose values are not
    tuples.

    :param query: a non-null query object
    :param cls: the database model class that filters will apply to
    :param search_opts: a dictionary whose key/value entries are interpreted as
                        column names and search values
    :returns: a tuple containing the modified query and a dictionary of
              unused search_opts
    """
    if not search_opts:
        return query, search_opts

    remaining = {}
    for k, v in six.iteritems(search_opts):
        if type(v) == tuple and k in cls.__table__.columns:
            col = cls.__table__.columns[k]
            query = query.filter(col.in_(v))
        else:
            remaining[k] = v
    return query, remaining


def like_filter(query, cls, search_opts):
    """Add 'like' filters for specified columns.

    Add a sqlalchemy 'like' filter to the query for any entry in the
    'search_opts' dict where the key is the name of a column in
    'cls' and the value is a string containing '%'.

    This allows the value of a column to be matched
    against simple sql string patterns using LIKE and the
    '%' wildcard.

    Return the modified query and any entries in search_opts
    whose keys do not match columns or whose values are not
    strings containing '%'.

    :param query: a non-null query object
    :param cls: the database model class the filters will apply to
    :param search_opts: a dictionary whose key/value entries are interpreted as
                        column names and search patterns
    :returns: a tuple containing the modified query and a dictionary of
              unused search_opts
    """
    if not search_opts:
        return query, search_opts

    remaining = {}
    for k, v in six.iteritems(search_opts):
        if isinstance(v, six.string_types) and (
                '%' in v and k in cls.__table__.columns):
            col = cls.__table__.columns[k]
            query = query.filter(col.like(v))
        else:
            remaining[k] = v
    return query, remaining


def setup_db():
    try:
        engine = get_engine()
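As a usage sketch (the filter values here are hypothetical), the two helpers are intended to be chained, each consuming the keys it understands and returning the remainder for a plain filter_by(), which is exactly how data_source_count() in the next hunk uses them:

    query = model_query(m.DataSource, context)
    opts = {'name': ('input1', 'input2'),  # tuple -> in_ filter
            'url': 'swift://%',            # contains '%' -> like filter
            'type': 'swift'}               # neither -> left for filter_by()
    query, opts = in_filter(query, m.DataSource, opts)
    query, opts = like_filter(query, m.DataSource, opts)
    count = query.filter_by(**opts).count()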
@@ -489,6 +559,37 @@ def data_source_get(context, data_source_id):
    return _data_source_get(context, get_session(), data_source_id)


def data_source_count(context, **kwargs):
    """Count DataSource objects filtered by search criteria in kwargs.

    Entries in kwargs indicate column names and search values.

    'in' filters will be used to search for any entries in kwargs
    that name DataSource columns and have values of type tuple. This
    allows column values to match multiple values (OR)

    'like' filters will be used for any entries in kwargs that
    name DataSource columns and have string values containing '%'.
    This allows column values to match simple wildcards.

    Any other entries in kwargs will be searched for using filter_by()
    """
    query = model_query(m.DataSource, context)
    query, kwargs = in_filter(query, m.DataSource, kwargs)
    query, kwargs = like_filter(query, m.DataSource, kwargs)

    # Use normal filter_by for remaining keys
    try:
        return query.filter_by(**kwargs).count()
    except Exception as e:
        if kwargs:
            # If kwargs is non-empty then we assume this
            # is a bad field reference. User asked for something
            # that doesn't exist, so return empty list
            return []
        raise e


def data_source_get_all(context, **kwargs):
    query = model_query(m.DataSource, context)
    try:
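Callers reach this through the conductor layer; for example, the proxy-user check later in this change counts swift-backed data sources whose names might be referenced. A hypothetical call (the names are placeholders):

    cnt = conductor.data_source_count(
        context.ctx(),
        name=('nightly-logs', 'weekly-logs'),  # matched with an in_ filter
        url='swift://%')                       # matched with a like filter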
@@ -21,8 +21,10 @@ import six

from sahara import conductor as c
from sahara import context
from sahara.openstack.common import uuidutils
from sahara.plugins import base as plugin_base
from sahara.service.edp.binary_retrievers import dispatch
from sahara.swift import swift_helper as sw
from sahara.utils import edp
from sahara.utils import remote

@@ -39,6 +41,12 @@ CONF.register_opts(opts)

conductor = c.API

# Prefix used to mark data_source name references in arg lists
DATA_SOURCE_PREFIX = "datasource://"

DATA_SOURCE_SUBST_NAME = "edp.substitute_data_source_for_name"
DATA_SOURCE_SUBST_UUID = "edp.substitute_data_source_for_uuid"


def get_plugin(cluster):
    return plugin_base.PLUGINS.get_plugin(cluster.plugin_name)
@@ -94,3 +102,175 @@ def _append_slash_if_needed(path):
    if path[-1] != '/':
        path += '/'
    return path


def may_contain_data_source_refs(job_configs):

    def _check_data_source_ref_option(option):
        truth = job_configs and (
            job_configs.get('configs', {}).get(option))
        # Config values specified in the UI may be
        # passed as strings
        return truth in (True, 'True')

    return (
        _check_data_source_ref_option(DATA_SOURCE_SUBST_NAME),
        _check_data_source_ref_option(DATA_SOURCE_SUBST_UUID))


def _data_source_ref_search(job_configs, func, prune=lambda x: x):
    """Return a list of unique values in job_configs filtered by func().

    Loop over the 'args', 'configs' and 'params' elements in
    job_configs and return a list of all values for which
    func(value) is True.

    Optionally provide a 'prune' function that is applied
    to values before they are added to the return value.
    """
    args = set([prune(arg) for arg in job_configs.get(
        'args', []) if func(arg)])

    configs = set([prune(val) for val in six.itervalues(
        job_configs.get('configs', {})) if func(val)])

    params = set([prune(val) for val in six.itervalues(
        job_configs.get('params', {})) if func(val)])

    return list(args | configs | params)


def find_possible_data_source_refs_by_name(job_configs):
    """Find string values in job_configs starting with 'datasource://'.

    Loop over the 'args', 'configs', and 'params' elements of
    job_configs to find all values beginning with the prefix
    'datasource://'. Return a list of unique values with the prefix
    removed.

    Note that for 'configs' and 'params', which are dictionaries, only
    the values are considered and the keys are not relevant.
    """
    def startswith(arg):
        return isinstance(
            arg,
            six.string_types) and arg.startswith(DATA_SOURCE_PREFIX)
    return _data_source_ref_search(job_configs,
                                   startswith,
                                   prune=lambda x: x[len(DATA_SOURCE_PREFIX):])


def find_possible_data_source_refs_by_uuid(job_configs):
    """Find string values in job_configs which are uuids.

    Return a list of unique values in the 'args', 'configs', and 'params'
    elements of job_configs which have the form of a uuid.

    Note that for 'configs' and 'params', which are dictionaries, only
    the values are considered and the keys are not relevant.
    """
    return _data_source_ref_search(job_configs, uuidutils.is_uuid_like)


def _add_credentials_for_data_sources(ds_list, configs):

    username = password = None
    for src in ds_list:
        if src.type == "swift" and hasattr(src, "credentials"):
            if "user" in src.credentials:
                username = src.credentials['user']
            if "password" in src.credentials:
                password = src.credentials['password']
            break

    # Don't overwrite if there is already a value here
    if configs.get(sw.HADOOP_SWIFT_USERNAME, None) is None and (
            username is not None):
        configs[sw.HADOOP_SWIFT_USERNAME] = username
    if configs.get(sw.HADOOP_SWIFT_PASSWORD, None) is None and (
            password is not None):
        configs[sw.HADOOP_SWIFT_PASSWORD] = password


def resolve_data_source_references(job_configs):
    """Resolve possible data_source references in job_configs.

    Look for any string values in the 'args', 'configs', and 'params'
    elements of job_configs which start with 'datasource://' or have
    the form of a uuid.

    For values beginning with 'datasource://', strip off the prefix
    and search for a DataSource object with a name that matches the
    value.

    For values having the form of a uuid, search for a DataSource object
    with an id that matches the value.

    If a DataSource object is found for the value, replace the value
    with the URL from the DataSource object. If any DataSource objects
    are found which reference swift paths and contain credentials, set
    credential configuration values in job_configs (use the first set
    of swift credentials found).

    If no values are resolved, return an empty list and a reference
    to job_configs.

    If any values are resolved, return a list of the referenced
    data_source objects and a copy of job_configs with all of the
    references replaced with URLs.
    """
    by_name, by_uuid = may_contain_data_source_refs(job_configs)
    if not (by_name or by_uuid):
        return [], job_configs

    ctx = context.ctx()
    ds_seen = {}
    new_configs = {}

    def _resolve(value):
        kwargs = {}
        if by_name and isinstance(
                value,
                six.string_types) and value.startswith(DATA_SOURCE_PREFIX):
            value = value[len(DATA_SOURCE_PREFIX):]
            kwargs['name'] = value

        elif by_uuid and uuidutils.is_uuid_like(value):
            kwargs['id'] = value

        if kwargs:
            # Name and id are both unique constraints so if there
            # is more than 1 something is really wrong
            ds = conductor.data_source_get_all(ctx, **kwargs)
            if len(ds) == 1:
                ds = ds[0]
                ds_seen[ds.id] = ds
                return ds.url
        return value

    # Loop over configs/params/args and look up each value as a data_source.
    # If we find it, replace the value. In all cases, we've produced a
    # copy which is not a FrozenClass type and can be updated.
    new_configs['configs'] = {
        k: _resolve(v) for k, v in six.iteritems(
            job_configs.get('configs', {}))}
    new_configs['params'] = {
        k: _resolve(v) for k, v in six.iteritems(
            job_configs.get('params', {}))}
    new_configs['args'] = [_resolve(a) for a in job_configs.get('args', [])]

    # If we didn't resolve anything we might as well return the original
    ds_seen = ds_seen.values()
    if not ds_seen:
        return [], job_configs

    # If there are no proxy_configs and the user has not already set configs
    # for swift credentials, set those configs based on data_sources we found
    if not job_configs.get('proxy_configs'):
        _add_credentials_for_data_sources(ds_seen, new_configs['configs'])
    else:
        # we'll need to copy these, too, so job_configs is complete
        new_configs['proxy_configs'] = {
            k: v for k, v in six.iteritems(job_configs.get('proxy_configs'))}

    return ds_seen, new_configs
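A sketch of how an EDP engine uses these helpers (the data source name is a placeholder; the real call sites are in the Oozie and Spark run_job() changes below):

    job_configs = {
        'configs': {DATA_SOURCE_SUBST_NAME: True},
        'args': ['datasource://nightly-logs', 'swift://container/out']}
    additional_sources, updated_job_configs = (
        resolve_data_source_references(job_configs))
    # additional_sources: DataSource objects that matched a reference
    # updated_job_configs: a copy with each matched reference replaced by
    # the DataSource url, or the original job_configs if nothing matched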
@@ -81,9 +81,20 @@ class OozieJobEngine(base_engine.JobEngine):
        job = conductor.job_get(ctx, job_execution.job_id)
        input_source, output_source = job_utils.get_data_sources(job_execution,
                                                                 job)
        proxy_configs = job_execution.job_configs.get('proxy_configs')

        for data_source in [input_source, output_source]:
        # Updated_job_configs will be a copy of job_execution.job_configs with
        # any name or uuid references to data_sources resolved to paths
        # assuming substitution is enabled.
        # If substitution is not enabled then updated_job_configs will
        # just be a reference to job_execution.job_configs to avoid a copy.
        # Additional_sources will be a list of any data_sources found.
        additional_sources, updated_job_configs = (
            job_utils.resolve_data_source_references(job_execution.job_configs)
        )

        proxy_configs = updated_job_configs.get('proxy_configs')

        for data_source in [input_source, output_source] + additional_sources:
            if data_source and data_source.type == 'hdfs':
                h.configure_cluster_for_hdfs(self.cluster, data_source)
                break
@@ -99,7 +110,8 @@ class OozieJobEngine(base_engine.JobEngine):
                                             proxy_configs)

        wf_xml = workflow_factory.get_workflow_xml(
            job, self.cluster, job_execution, input_source, output_source,
            job, self.cluster, updated_job_configs,
            input_source, output_source,
            hdfs_user)

        path_to_workflow = self._upload_workflow_file(oozie_server, wf_dir,
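Read together, the new lines above make the Oozie run_job() path resolve references once and thread the resolved copy through workflow generation; condensed, the flow is:

    additional_sources, updated_job_configs = (
        job_utils.resolve_data_source_references(job_execution.job_configs))
    proxy_configs = updated_job_configs.get('proxy_configs')
    for data_source in [input_source, output_source] + additional_sources:
        if data_source and data_source.type == 'hdfs':
            h.configure_cluster_for_hdfs(self.cluster, data_source)
            break
    wf_xml = workflow_factory.get_workflow_xml(
        job, self.cluster, updated_job_configs,
        input_source, output_source, hdfs_user)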
@@ -140,14 +140,14 @@ class PigFactory(BaseFactory):
    def get_script_name(self, job):
        return conductor.job_main_name(context.ctx(), job)

    def get_workflow_xml(self, cluster, execution, input_data, output_data,
    def get_workflow_xml(self, cluster, job_configs, input_data, output_data,
                         hdfs_user):
        proxy_configs = execution.job_configs.get('proxy_configs')
        proxy_configs = job_configs.get('proxy_configs')
        job_dict = {'configs': self.get_configs(input_data, output_data,
                                                proxy_configs),
                    'params': self.get_params(input_data, output_data),
                    'args': []}
        self.update_job_dict(job_dict, execution.job_configs)
        self.update_job_dict(job_dict, job_configs)
        creator = pig_workflow.PigWorkflowCreator()
        creator.build_workflow_xml(self.name,
                                   configuration=job_dict['configs'],
@@ -165,13 +165,13 @@ class HiveFactory(BaseFactory):
    def get_script_name(self, job):
        return conductor.job_main_name(context.ctx(), job)

    def get_workflow_xml(self, cluster, execution, input_data, output_data,
    def get_workflow_xml(self, cluster, job_configs, input_data, output_data,
                         hdfs_user):
        proxy_configs = execution.job_configs.get('proxy_configs')
        proxy_configs = job_configs.get('proxy_configs')
        job_dict = {'configs': self.get_configs(input_data, output_data,
                                                proxy_configs),
                    'params': self.get_params(input_data, output_data)}
        self.update_job_dict(job_dict, execution.job_configs)
        self.update_job_dict(job_dict, job_configs)

        creator = hive_workflow.HiveWorkflowCreator()
        creator.build_workflow_xml(self.name,
@@ -196,12 +196,12 @@ class MapReduceFactory(BaseFactory):
        return dict((k[len(prefix):], v) for (k, v) in six.iteritems(
            job_dict['edp_configs']) if k.startswith(prefix))

    def get_workflow_xml(self, cluster, execution, input_data, output_data,
    def get_workflow_xml(self, cluster, job_configs, input_data, output_data,
                         hdfs_user):
        proxy_configs = execution.job_configs.get('proxy_configs')
        proxy_configs = job_configs.get('proxy_configs')
        job_dict = {'configs': self.get_configs(input_data, output_data,
                                                proxy_configs)}
        self.update_job_dict(job_dict, execution.job_configs)
        self.update_job_dict(job_dict, job_configs)
        creator = mapreduce_workflow.MapReduceWorkFlowCreator()
        creator.build_workflow_xml(configuration=job_dict['configs'],
                                   streaming=self._get_streaming(job_dict))
@@ -230,11 +230,11 @@ class JavaFactory(BaseFactory):

        return configs

    def get_workflow_xml(self, cluster, execution, *args, **kwargs):
        proxy_configs = execution.job_configs.get('proxy_configs')
    def get_workflow_xml(self, cluster, job_configs, *args, **kwargs):
        proxy_configs = job_configs.get('proxy_configs')
        job_dict = {'configs': self.get_configs(proxy_configs=proxy_configs),
                    'args': []}
        self.update_job_dict(job_dict, execution.job_configs)
        self.update_job_dict(job_dict, job_configs)

        main_class, java_opts = self._get_java_configs(job_dict)
        creator = java_workflow.JavaWorkflowCreator()
@@ -264,9 +264,9 @@ def _get_creator(job):
    return type_map[job.type]()


def get_workflow_xml(job, cluster, execution, *args, **kwargs):
def get_workflow_xml(job, cluster, job_configs, *args, **kwargs):
    return _get_creator(job).get_workflow_xml(
        cluster, execution, *args, **kwargs)
        cluster, job_configs, *args, **kwargs)


def get_possible_job_config(job_type):
@@ -114,7 +114,11 @@ class SparkJobEngine(base_engine.JobEngine):
        ctx = context.ctx()
        job = conductor.job_get(ctx, job_execution.job_id)

        proxy_configs = job_execution.job_configs.get('proxy_configs')
        additional_sources, updated_job_configs = (
            job_utils.resolve_data_source_references(job_execution.job_configs)
        )

        proxy_configs = updated_job_configs.get('proxy_configs')

        # We'll always run the driver program on the master
        master = plugin_utils.get_instance(self.cluster, "master")
@@ -150,11 +154,11 @@ class SparkJobEngine(base_engine.JobEngine):
                                              self.cluster),
                                 "bin/spark-submit")

        job_class = job_execution.job_configs.configs["edp.java.main_class"]
        job_class = updated_job_configs['configs']["edp.java.main_class"]

        # TODO(tmckay): we need to clean up wf_dirs on long running clusters
        # TODO(tmckay): probably allow for general options to spark-submit
        args = " ".join(job_execution.job_configs.get('args', []))
        args = " ".join(updated_job_configs.get('args', []))

        # The redirects of stdout and stderr will preserve output in the wf_dir
        cmd = "%s %s --class %s %s --master spark://%s:%s %s" % (
@@ -21,7 +21,7 @@ import uuid
import fixtures
import six

from sahara.swift import swift_helper as sw
from sahara.service.edp import job_utils
from sahara.tests.integration.tests import base
from sahara.utils import edp

@@ -204,17 +204,13 @@ class EDPTest(base.ITestCase):
            )
        )

    def _add_swift_configs(self, configs):
    def _enable_substitution(self, configs):

        if "configs" not in configs:
            configs["configs"] = {}

        if sw.HADOOP_SWIFT_USERNAME not in configs["configs"]:
            configs["configs"][
                sw.HADOOP_SWIFT_USERNAME] = self.common_config.OS_USERNAME
        if sw.HADOOP_SWIFT_PASSWORD not in configs["configs"]:
            configs["configs"][
                sw.HADOOP_SWIFT_PASSWORD] = self.common_config.OS_PASSWORD
        configs['configs'][job_utils.DATA_SOURCE_SUBST_NAME] = True
        configs['configs'][job_utils.DATA_SOURCE_SUBST_UUID] = True

    @base.skip_test('SKIP_EDP_TEST', 'Test for EDP was skipped.')
    def check_edp_hive(self):
@@ -279,17 +275,14 @@
        output_type = "swift"
        output_url = 'swift://%s.sahara/output' % container_name

        # Java jobs don't use data sources. Input/output paths must
        # be passed as args with corresponding username/password configs
        if not edp.compare_job_type(job_type,
                                    edp.JOB_TYPE_JAVA,
                                    edp.JOB_TYPE_SPARK):
            input_id = self._create_data_source(
                'input-%s' % str(uuid.uuid4())[:8], 'swift',
                swift_input_url)
            output_id = self._create_data_source(
                'output-%s' % str(uuid.uuid4())[:8], output_type,
                output_url)
            input_name = 'input-%s' % str(uuid.uuid4())[:8]
            input_id = self._create_data_source(input_name,
                                                'swift', swift_input_url)

            output_name = 'output-%s' % str(uuid.uuid4())[:8]
            output_id = self._create_data_source(output_name,
                                                 output_type,
                                                 output_url)

        if job_data_list:
            if swift_binaries:
@@ -329,11 +322,13 @@
        # if the caller has requested it...
        if edp.compare_job_type(
                job_type, edp.JOB_TYPE_JAVA) and pass_input_output_args:
            self._add_swift_configs(configs)
            self._enable_substitution(configs)
            input_arg = job_utils.DATA_SOURCE_PREFIX + input_name
            output_arg = output_id
            if "args" in configs:
                configs["args"].extend([swift_input_url, output_url])
                configs["args"].extend([input_arg, output_arg])
            else:
                configs["args"] = [swift_input_url, output_url]
                configs["args"] = [input_arg, output_arg]

        job_execution = self.sahara.job_executions.create(
            job_id, self.cluster_id, input_id, output_id,
@@ -158,6 +158,57 @@ class DataSourceTest(test_base.ConductorManagerTestCase):
        lst = self.api.data_source_get_all(ctx, **{'badfield': 'somevalue'})
        self.assertEqual(len(lst), 0)

    def test_data_source_count_in(self):
        ctx = context.ctx()
        ctx.tenant_id = SAMPLE_DATA_SOURCE['tenant_id']
        src = copy.copy(SAMPLE_DATA_SOURCE)
        self.api.data_source_create(ctx, src)

        cnt = self.api.data_source_count(ctx, name='ngt_test')
        self.assertEqual(cnt, 1)

        cnt = self.api.data_source_count(ctx, name=('ngt_test',
                                                    'test2', 'test3'))
        self.assertEqual(cnt, 1)

        cnt = self.api.data_source_count(ctx, name=('test1',
                                                    'test2', 'test3'))
        self.assertEqual(cnt, 0)

        lst = self.api.data_source_get_all(ctx, name='ngt_test')
        myid = lst[0]['id']
        cnt = self.api.data_source_count(ctx,
                                         name=('ngt_test', 'test2', 'test3'),
                                         id=myid)
        self.assertEqual(cnt, 1)

        cnt = self.api.data_source_count(ctx,
                                         name=('ngt_test', 'test2', 'test3'),
                                         id=(myid, '2'))
        self.assertEqual(cnt, 1)

    def test_data_source_count_like(self):
        ctx = context.ctx()
        ctx.tenant_id = SAMPLE_DATA_SOURCE['tenant_id']
        src = copy.copy(SAMPLE_DATA_SOURCE)
        self.api.data_source_create(ctx, src)

        cnt = self.api.data_source_count(ctx, name='ngt_test')
        self.assertEqual(cnt, 1)

        cnt = self.api.data_source_count(ctx, name='ngt%')
        self.assertEqual(cnt, 1)

        cnt = self.api.data_source_count(ctx,
                                         name=('ngt_test',),
                                         url='localhost%')
        self.assertEqual(cnt, 1)

        cnt = self.api.data_source_count(ctx,
                                         name=('ngt_test',),
                                         url='localhost')
        self.assertEqual(cnt, 0)


class JobExecutionTest(test_base.ConductorManagerTestCase):
    def test_crud_operation_create_list_delete_update(self):
@@ -68,7 +68,7 @@ def create_cluster(plugin_name='vanilla', hadoop_version='1.2.1'):
    return cluster


def create_data_source(url):
def create_data_source(url, name=None, id=None):
    data_source = mock.Mock()
    data_source.url = url
    if url.startswith("swift"):
@@ -77,6 +77,10 @@ def create_data_source(url):
                                   'password': 'admin1'}
    elif url.startswith("hdfs"):
        data_source.type = "hdfs"
    if name is not None:
        data_source.name = name
    if id is not None:
        data_source.id = id
    return data_source
@@ -323,11 +323,6 @@ class TestSpark(base.SaharaTestCase):
                     upload_job_files, get_config_value, get_remote,
                     job_exec_get):

        def fix_get(field, default=None):
            if field == "args":
                return ["input_arg", "output_arg"]
            return default

        eng = se.SparkJobEngine("cluster")

        job = mock.Mock()
@@ -335,9 +330,9 @@ class TestSpark(base.SaharaTestCase):
        job_get.return_value = job

        job_exec = mock.Mock()
        job_exec.job_configs.configs = {"edp.java.main_class":
                                        "org.me.myclass"}
        job_exec.job_configs.get = fix_get
        job_exec.job_configs = {'configs': {"edp.java.main_class":
                                            "org.me.myclass"},
                                'args': ['input_arg', 'output_arg']}

        master = mock.Mock()
        get_instance.return_value = master
@@ -120,8 +120,8 @@ class TestJobManager(base.SaharaWithDbTestCase):
        output_data = u.create_data_source('swift://ex/o')

        res = workflow_factory.get_workflow_xml(
            job, u.create_cluster(), job_exec, input_data, output_data,
            'hadoop')
            job, u.create_cluster(), job_exec.job_configs,
            input_data, output_data, 'hadoop')

        self.assertIn("""
<param>INPUT=swift://ex.sahara/i</param>
@@ -147,8 +147,8 @@ class TestJobManager(base.SaharaWithDbTestCase):
        job, job_exec = u.create_job_exec(edp.JOB_TYPE_PIG, proxy=True)

        res = workflow_factory.get_workflow_xml(
            job, u.create_cluster(), job_exec, input_data, output_data,
            'hadoop')
            job, u.create_cluster(), job_exec.job_configs,
            input_data, output_data, 'hadoop')

        self.assertIn("""
<configuration>
@@ -181,8 +181,8 @@ class TestJobManager(base.SaharaWithDbTestCase):
        output_data = u.create_data_source('hdfs://user/hadoop/out')

        res = workflow_factory.get_workflow_xml(
            job, u.create_cluster(), job_exec, input_data, output_data,
            'hadoop')
            job, u.create_cluster(), job_exec.job_configs,
            input_data, output_data, 'hadoop')

        self.assertIn("""
<configuration>
@@ -200,8 +200,8 @@ class TestJobManager(base.SaharaWithDbTestCase):
        output_data = u.create_data_source('swift://ex/o')

        res = workflow_factory.get_workflow_xml(
            job, u.create_cluster(), job_exec, input_data, output_data,
            'hadoop')
            job, u.create_cluster(), job_exec.job_configs,
            input_data, output_data, 'hadoop')

        self.assertIn("""
<configuration>
@@ -221,8 +221,8 @@ class TestJobManager(base.SaharaWithDbTestCase):
        output_data = u.create_data_source('hdfs://user/hadoop/out')

        res = workflow_factory.get_workflow_xml(
            job, u.create_cluster(), job_exec, input_data, output_data,
            'hadoop')
            job, u.create_cluster(), job_exec.job_configs,
            input_data, output_data, 'hadoop')

        self.assertIn("""
<configuration>
@@ -246,8 +246,8 @@ class TestJobManager(base.SaharaWithDbTestCase):
        output_data = u.create_data_source('swift://ex/o')

        res = workflow_factory.get_workflow_xml(
            job, u.create_cluster(), job_exec, input_data, output_data,
            'hadoop')
            job, u.create_cluster(), job_exec.job_configs,
            input_data, output_data, 'hadoop')

        if streaming:
            self.assertIn("""
@@ -288,8 +288,8 @@ class TestJobManager(base.SaharaWithDbTestCase):
        job, job_exec = u.create_job_exec(job_type, proxy=True)

        res = workflow_factory.get_workflow_xml(
            job, u.create_cluster(), job_exec, input_data, output_data,
            'hadoop')
            job, u.create_cluster(), job_exec.job_configs,
            input_data, output_data, 'hadoop')

        self.assertIn("""
<property>
@@ -331,7 +331,7 @@ class TestJobManager(base.SaharaWithDbTestCase):

        job, job_exec = u.create_job_exec(edp.JOB_TYPE_JAVA, configs)
        res = workflow_factory.get_workflow_xml(
            job, u.create_cluster(), job_exec)
            job, u.create_cluster(), job_exec.job_configs)

        self.assertIn("""
<configuration>
@@ -361,7 +361,7 @@ class TestJobManager(base.SaharaWithDbTestCase):

        job, job_exec = u.create_job_exec(edp.JOB_TYPE_JAVA, configs,
                                          proxy=True)
        res = workflow_factory.get_workflow_xml(job, u.create_cluster(),
                                                job_exec)
                                                job_exec.job_configs)

        self.assertIn("""
<configuration>
@@ -397,8 +397,8 @@ class TestJobManager(base.SaharaWithDbTestCase):
        output_data = u.create_data_source('swift://ex/o')

        res = workflow_factory.get_workflow_xml(
            job, u.create_cluster(), job_exec, input_data, output_data,
            'hadoop')
            job, u.create_cluster(), job_exec.job_configs,
            input_data, output_data, 'hadoop')

        doc = xml.parseString(res)
        hive = doc.getElementsByTagName('hive')[0]
@@ -425,8 +425,8 @@ class TestJobManager(base.SaharaWithDbTestCase):
        job, job_exec = u.create_job_exec(edp.JOB_TYPE_HIVE, proxy=True)

        res = workflow_factory.get_workflow_xml(
            job, u.create_cluster(), job_exec, input_data, output_data,
            'hadoop')
            job, u.create_cluster(), job_exec.job_configs,
            input_data, output_data, 'hadoop')

        doc = xml.parseString(res)
        hive = doc.getElementsByTagName('hive')[0]
sahara/tests/unit/service/edp/test_job_utils.py (new file, 249 lines)
@@ -0,0 +1,249 @@
# Copyright (c) 2013 Mirantis Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import uuid

import mock
import six
import testtools

from sahara import conductor as cond
from sahara.service.edp import job_utils
from sahara.tests.unit.service.edp import edp_test_utils as u

conductor = cond.API


class JobUtilsTestCase(testtools.TestCase):

    def setUp(self):
        super(JobUtilsTestCase, self).setUp()

    def test_args_may_contain_data_sources(self):
        job_configs = None

        # No configs, default false
        by_name, by_uuid = job_utils.may_contain_data_source_refs(job_configs)
        self.assertFalse(by_name | by_uuid)

        # Empty configs, default false
        job_configs = {'configs': {}}
        by_name, by_uuid = job_utils.may_contain_data_source_refs(job_configs)
        self.assertFalse(by_name | by_uuid)

        job_configs['configs'] = {job_utils.DATA_SOURCE_SUBST_NAME: True,
                                  job_utils.DATA_SOURCE_SUBST_UUID: True}
        by_name, by_uuid = job_utils.may_contain_data_source_refs(job_configs)
        self.assertTrue(by_name & by_uuid)

        job_configs['configs'][job_utils.DATA_SOURCE_SUBST_NAME] = False
        by_name, by_uuid = job_utils.may_contain_data_source_refs(job_configs)
        self.assertFalse(by_name)
        self.assertTrue(by_uuid)

        job_configs['configs'][job_utils.DATA_SOURCE_SUBST_UUID] = False
        by_name, by_uuid = job_utils.may_contain_data_source_refs(job_configs)
        self.assertFalse(by_name | by_uuid)

        job_configs['configs'] = {job_utils.DATA_SOURCE_SUBST_NAME: 'True',
                                  job_utils.DATA_SOURCE_SUBST_UUID: 'Fish'}
        by_name, by_uuid = job_utils.may_contain_data_source_refs(job_configs)
        self.assertTrue(by_name)
        self.assertFalse(by_uuid)

    def test_find_possible_data_source_refs_by_name(self):
        id = six.text_type(uuid.uuid4())
        job_configs = {}
        self.assertEqual([],
                         job_utils.find_possible_data_source_refs_by_name(
                             job_configs))

        name_ref = job_utils.DATA_SOURCE_PREFIX+'name'
        name_ref2 = name_ref+'2'

        job_configs = {'args': ['first', id],
                       'configs': {'config': 'value'},
                       'params': {'param': 'value'}}
        self.assertEqual([],
                         job_utils.find_possible_data_source_refs_by_name(
                             job_configs))

        job_configs = {'args': [name_ref, id],
                       'configs': {'config': 'value'},
                       'params': {'param': 'value'}}
        self.assertEqual(
            ['name'],
            job_utils.find_possible_data_source_refs_by_name(job_configs))

        job_configs = {'args': ['first', id],
                       'configs': {'config': name_ref},
                       'params': {'param': 'value'}}
        self.assertEqual(
            ['name'],
            job_utils.find_possible_data_source_refs_by_name(job_configs))

        job_configs = {'args': ['first', id],
                       'configs': {'config': 'value'},
                       'params': {'param': name_ref}}
        self.assertEqual(
            ['name'],
            job_utils.find_possible_data_source_refs_by_name(job_configs))

        job_configs = {'args': [name_ref, name_ref2, id],
                       'configs': {'config': name_ref},
                       'params': {'param': name_ref}}
        self.assertItemsEqual(
            ['name', 'name2'],
            job_utils.find_possible_data_source_refs_by_name(job_configs))

    def test_find_possible_data_source_refs_by_uuid(self):
        job_configs = {}

        name_ref = job_utils.DATA_SOURCE_PREFIX+'name'

        self.assertEqual([],
                         job_utils.find_possible_data_source_refs_by_uuid(
                             job_configs))

        id = six.text_type(uuid.uuid4())
        job_configs = {'args': ['first', name_ref],
                       'configs': {'config': 'value'},
                       'params': {'param': 'value'}}
        self.assertEqual([],
                         job_utils.find_possible_data_source_refs_by_uuid(
                             job_configs))

        job_configs = {'args': [id, name_ref],
                       'configs': {'config': 'value'},
                       'params': {'param': 'value'}}
        self.assertEqual(
            [id],
            job_utils.find_possible_data_source_refs_by_uuid(job_configs))

        job_configs = {'args': ['first', name_ref],
                       'configs': {'config': id},
                       'params': {'param': 'value'}}
        self.assertEqual(
            [id],
            job_utils.find_possible_data_source_refs_by_uuid(job_configs))

        job_configs = {'args': ['first', name_ref],
                       'configs': {'config': 'value'},
                       'params': {'param': id}}
        self.assertEqual(
            [id],
            job_utils.find_possible_data_source_refs_by_uuid(job_configs))

        id2 = six.text_type(uuid.uuid4())
        job_configs = {'args': [id, id2, name_ref],
                       'configs': {'config': id},
                       'params': {'param': id}}
        self.assertItemsEqual([id, id2],
                              job_utils.find_possible_data_source_refs_by_uuid(
                                  job_configs))

    @mock.patch('sahara.context.ctx')
    @mock.patch('sahara.conductor.API.data_source_get_all')
    def test_resolve_data_source_refs(self, data_source_get_all, ctx):

        ctx.return_value = 'dummy'

        name_ref = job_utils.DATA_SOURCE_PREFIX+'input'

        input = u.create_data_source("swift://container/input",
                                     name="input",
                                     id=six.text_type(uuid.uuid4()))

        output = u.create_data_source("swift://container/output",
                                      name="output",
                                      id=six.text_type(uuid.uuid4()))

        by_name = {'input': input,
                   'output': output}

        by_id = {input.id: input,
                 output.id: output}

        # Pretend to be the database
        def _get_all(ctx, **kwargs):
            name = kwargs.get('name')
            if name in by_name:
                name_list = [by_name[name]]
            else:
                name_list = []

            id = kwargs.get('id')
            if id in by_id:
                id_list = [by_id[id]]
            else:
                id_list = []
            return list(set(name_list + id_list))

        data_source_get_all.side_effect = _get_all

        job_configs = {
            'configs': {
                job_utils.DATA_SOURCE_SUBST_NAME: True,
                job_utils.DATA_SOURCE_SUBST_UUID: True},
            'args': [name_ref, output.id, input.id]}

        ds, nc = job_utils.resolve_data_source_references(job_configs)
        self.assertEqual(len(ds), 2)
        self.assertEqual(nc['args'], [input.url, output.url, input.url])
        # Swift configs should be filled in since they were blank
        self.assertEqual(nc['configs']['fs.swift.service.sahara.username'],
                         input.credentials['user'])
        self.assertEqual(nc['configs']['fs.swift.service.sahara.password'],
                         input.credentials['password'])

        job_configs['configs'] = {'fs.swift.service.sahara.username': 'sam',
                                  'fs.swift.service.sahara.password': 'gamgee',
                                  job_utils.DATA_SOURCE_SUBST_NAME: False,
                                  job_utils.DATA_SOURCE_SUBST_UUID: True}
        ds, nc = job_utils.resolve_data_source_references(job_configs)
        self.assertEqual(len(ds), 2)
        self.assertEqual(nc['args'], [name_ref, output.url, input.url])
        # Swift configs should not be overwritten
        self.assertEqual(nc['configs'], job_configs['configs'])

        job_configs['configs'] = {job_utils.DATA_SOURCE_SUBST_NAME: True,
                                  job_utils.DATA_SOURCE_SUBST_UUID: False}
        job_configs['proxy_configs'] = {'proxy_username': 'john',
                                        'proxy_password': 'smith',
                                        'proxy_trust_id': 'trustme'}
        ds, nc = job_utils.resolve_data_source_references(job_configs)
        self.assertEqual(len(ds), 1)
        self.assertEqual(nc['args'], [input.url, output.id, input.id])

        # Swift configs should be empty and proxy configs should be preserved
        self.assertEqual(nc['configs'], job_configs['configs'])
        self.assertEqual(nc['proxy_configs'], job_configs['proxy_configs'])

        # Substitution not enabled
        job_configs['configs'] = {job_utils.DATA_SOURCE_SUBST_NAME: False,
                                  job_utils.DATA_SOURCE_SUBST_UUID: False}
        ds, nc = job_utils.resolve_data_source_references(job_configs)
        self.assertEqual(len(ds), 0)
        self.assertEqual(nc['args'], job_configs['args'])
        self.assertEqual(nc['configs'], job_configs['configs'])

        # Substitution enabled but no values to modify
        job_configs['configs'] = {job_utils.DATA_SOURCE_SUBST_NAME: True,
                                  job_utils.DATA_SOURCE_SUBST_UUID: True}
        job_configs['args'] = ['val1', 'val2', 'val3']
        ds, nc = job_utils.resolve_data_source_references(job_configs)
        self.assertEqual(len(ds), 0)
        self.assertEqual(nc['args'], job_configs['args'])
        self.assertEqual(nc['configs'], job_configs['configs'])
@@ -13,8 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import mock
import uuid

import mock
import six

from sahara.service.edp import job_utils
from sahara.tests.unit import base
from sahara.utils import proxy as p

@@ -25,7 +29,14 @@ class TestProxyUtils(base.SaharaWithDbTestCase):

    @mock.patch('sahara.conductor.API.job_get')
    @mock.patch('sahara.conductor.API.data_source_get')
    def test_job_execution_requires_proxy_user(self, data_source, job):
    @mock.patch('sahara.conductor.API.data_source_count')
    @mock.patch('sahara.context.ctx')
    def test_job_execution_requires_proxy_user(self,
                                               ctx,
                                               data_source_count,
                                               data_source,
                                               job):

        self.override_config('use_domain_for_proxy_users', True)
        job_execution = mock.Mock(input_id=1,
                                  output_id=2,
@@ -44,8 +55,47 @@ class TestProxyUtils(base.SaharaWithDbTestCase):
            libs=[mock.Mock(url='swift://container/object')])
        self.assertTrue(p.job_execution_requires_proxy_user(job_execution))

        job_execution.job_configs['args'] = ['swift://container/object']
        job_execution.job_configs = {'args': ['swift://container/object']}
        job.return_value = mock.Mock(
            mains=[],
            libs=[])
        self.assertTrue(p.job_execution_requires_proxy_user(job_execution))

        job_execution.job_configs = {
            'configs': {'key': 'swift://container/object'}}
        self.assertTrue(p.job_execution_requires_proxy_user(job_execution))

        job_execution.job_configs = {
            'params': {'key': 'swift://container/object'}}
        self.assertTrue(p.job_execution_requires_proxy_user(job_execution))

        data_source_count.return_value = 0
        job_execution.job_configs = {
            'configs': {job_utils.DATA_SOURCE_SUBST_NAME: True}}
        job.return_value = mock.Mock(
            mains=[],
            libs=[])
        self.assertFalse(p.job_execution_requires_proxy_user(job_execution))

        ctx.return_value = 'dummy'
        data_source_count.return_value = 1
        job_execution.job_configs = {
            'configs': {job_utils.DATA_SOURCE_SUBST_NAME: True},
            'args': [job_utils.DATA_SOURCE_PREFIX+'somevalue']}
        self.assertTrue(p.job_execution_requires_proxy_user(job_execution))
        data_source_count.assert_called_with('dummy',
                                             name=('somevalue',),
                                             url='swift://%')
        data_source_count.reset_mock()
        data_source_count.return_value = 1
        myid = six.text_type(uuid.uuid4())
        job_execution.job_configs = {
            'configs': {job_utils.DATA_SOURCE_SUBST_UUID: True},
            'args': [myid]}
        job.return_value = mock.Mock(
            mains=[],
            libs=[])
        self.assertTrue(p.job_execution_requires_proxy_user(job_execution))
        data_source_count.assert_called_with('dummy',
                                             id=(myid,),
                                             url='swift://%')
@@ -23,6 +23,7 @@ from sahara import context
from sahara import exceptions as ex
from sahara.i18n import _
from sahara.openstack.common import log as logging
from sahara.service.edp import job_utils
from sahara.service import trusts as t
from sahara.swift import utils as su
from sahara.utils.openstack import keystone as k
@@ -180,27 +181,63 @@ def domain_for_proxy():

def job_execution_requires_proxy_user(job_execution):
    '''Returns True if the job execution requires a proxy user.'''

    def _check_values(values):
        return any(value.startswith(
            su.SWIFT_INTERNAL_PREFIX) for value in values if (
                isinstance(value, six.string_types)))

    if CONF.use_domain_for_proxy_users is False:
        return False
    input_ds = conductor.data_source_get(context.ctx(),
                                         job_execution.input_id)
    if input_ds and input_ds.url.startswith(su.SWIFT_INTERNAL_PREFIX):
        return True
    output_ds = conductor.data_source_get(context.ctx(),
                                          job_execution.output_id)
    if output_ds and output_ds.url.startswith(su.SWIFT_INTERNAL_PREFIX):
        return True
    if job_execution.job_configs.get('args'):
        for arg in job_execution.job_configs['args']:
            if arg.startswith(su.SWIFT_INTERNAL_PREFIX):
                return True

    paths = [conductor.data_source_get(context.ctx(), job_execution.output_id),
             conductor.data_source_get(context.ctx(), job_execution.input_id)]
    if _check_values(ds.url for ds in paths if ds):
        return True

    if _check_values(six.itervalues(
            job_execution.job_configs.get('configs', {}))):
        return True

    if _check_values(six.itervalues(
            job_execution.job_configs.get('params', {}))):
        return True

    if _check_values(job_execution.job_configs.get('args', [])):
        return True

    job = conductor.job_get(context.ctx(), job_execution.job_id)
    for main in job.mains:
        if main.url.startswith(su.SWIFT_INTERNAL_PREFIX):
    if _check_values(main.url for main in job.mains):
        return True

    if _check_values(lib.url for lib in job.libs):
        return True

    # We did the simple checks, now if data_source referencing is
    # enabled and we have values that could be a name or uuid,
    # query for data_sources that match and contain a swift path
    by_name, by_uuid = job_utils.may_contain_data_source_refs(
        job_execution.job_configs)
    if by_name:
        names = tuple(job_utils.find_possible_data_source_refs_by_name(
            job_execution.job_configs))
        # do a query here for name in names and path starts with swift-prefix
        if names and conductor.data_source_count(
                context.ctx(),
                name=names,
                url=su.SWIFT_INTERNAL_PREFIX+'%') > 0:
            return True
    for lib in job.libs:
        if lib.url.startswith(su.SWIFT_INTERNAL_PREFIX):

    if by_uuid:
        uuids = tuple(job_utils.find_possible_data_source_refs_by_uuid(
            job_execution.job_configs))
        # do a query here for id in uuids and path starts with swift-prefix
        if uuids and conductor.data_source_count(
                context.ctx(),
                id=uuids,
                url=su.SWIFT_INTERNAL_PREFIX+'%') > 0:
            return True

    return False