Add options supporting DataSource identifiers in job_configs

This change adds options that allow DataSource objects to be
referenced by name or uuid in the job_configs dictionary of a
job_execution. If a reference to a DataSource is found, its path
information replaces the reference.

Note: references are partially resolved in early processing to
determine whether a proxy user must be created. References are
fully resolved in run_job().
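
For illustration, a job_configs dictionary that opts in to substitution
might look like this (the data source name and uuid are hypothetical):

    {
        "configs": {
            "edp.substitute_data_source_for_name": True,
            "edp.substitute_data_source_for_uuid": True
        },
        "args": ["datasource://my_input",
                 "<uuid of an existing DataSource>"]
    }

At run time the first arg is replaced with the path of the DataSource
named "my_input" and the second with the path of the DataSource having
the given uuid.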

Implements: blueprint edp-data-sources-in-job-configs
Change-Id: I5be62b798b86a8aaf933c2cc6b6d5a252f0a8627
Author: Trevor McKay  2014-12-16 17:42:57 -05:00
commit 8750ddc121 (parent c7dc7968db)
16 changed files with 793 additions and 90 deletions

@ -235,6 +235,14 @@ class LocalApi(object):
"""
return self._manager.data_source_get_all(context, **kwargs)
def data_source_count(self, context, **kwargs):
"""Count Data Sources filtered by **kwargs.
Uses sqlalchemy "in_" clause for any tuple values
Uses sqlalchemy "like" clause for any string values containing %
"""
return self._manager.data_source_count(context, **kwargs)
@r.wrap(r.DataSource)
def data_source_create(self, context, values):
"""Create a Data Source from the values dictionary."""

@ -289,6 +289,14 @@ class ConductorManager(db_base.Base):
"""
return self.db.data_source_get_all(context, **kwargs)
def data_source_count(self, context, **kwargs):
"""Count Data Sources filtered by **kwargs.
Uses sqlalchemy "in_" clause for any tuple values
Uses sqlalchemy "like" clause for any string values containing %
"""
return self.db.data_source_count(context, **kwargs)
def data_source_create(self, context, values):
"""Create a Data Source from the values dictionary."""
values = copy.deepcopy(values)

@ -262,6 +262,15 @@ def data_source_get_all(context, **kwargs):
    return IMPL.data_source_get_all(context, **kwargs)


def data_source_count(context, **kwargs):
    """Count Data Sources filtered by **kwargs.

    Uses sqlalchemy "in_" clause for any tuple values
    Uses sqlalchemy "like" clause for any string values containing %
    """
    return IMPL.data_source_count(context, **kwargs)


@to_dict
def data_source_create(context, values):
    """Create a Data Source from the values dictionary."""

@ -99,6 +99,76 @@ def count_query(model, context, session=None, project_only=None):
    return model_query(sa.func.count(model.id), context, session, project_only)


def in_filter(query, cls, search_opts):
    """Add 'in' filters for specified columns.

    Add a sqlalchemy 'in' filter to the query for any entry in the
    'search_opts' dict where the key is the name of a column in
    'cls' and the value is a tuple.

    This allows the value of a column to be matched
    against multiple possible values (OR).

    Return the modified query and any entries in search_opts
    whose keys do not match columns or whose values are not
    tuples.

    :param query: a non-null query object
    :param cls: the database model class that filters will apply to
    :param search_opts: a dictionary whose key/value entries are interpreted as
                        column names and search values
    :returns: a tuple containing the modified query and a dictionary of
              unused search_opts
    """
    if not search_opts:
        return query, search_opts

    remaining = {}
    for k, v in six.iteritems(search_opts):
        if type(v) == tuple and k in cls.__table__.columns:
            col = cls.__table__.columns[k]
            query = query.filter(col.in_(v))
        else:
            remaining[k] = v
    return query, remaining


def like_filter(query, cls, search_opts):
    """Add 'like' filters for specified columns.

    Add a sqlalchemy 'like' filter to the query for any entry in the
    'search_opts' dict where the key is the name of a column in
    'cls' and the value is a string containing '%'.

    This allows the value of a column to be matched
    against simple sql string patterns using LIKE and the
    '%' wildcard.

    Return the modified query and any entries in search_opts
    whose keys do not match columns or whose values are not
    strings containing '%'.

    :param query: a non-null query object
    :param cls: the database model class the filters will apply to
    :param search_opts: a dictionary whose key/value entries are interpreted as
                        column names and search patterns
    :returns: a tuple containing the modified query and a dictionary of
              unused search_opts
    """
    if not search_opts:
        return query, search_opts

    remaining = {}
    for k, v in six.iteritems(search_opts):
        if isinstance(v, six.string_types) and (
                '%' in v and k in cls.__table__.columns):
            col = cls.__table__.columns[k]
            query = query.filter(col.like(v))
        else:
            remaining[k] = v
    return query, remaining


def setup_db():
    try:
        engine = get_engine()
@ -489,6 +559,37 @@ def data_source_get(context, data_source_id):
    return _data_source_get(context, get_session(), data_source_id)


def data_source_count(context, **kwargs):
    """Count DataSource objects filtered by search criteria in kwargs.

    Entries in kwargs indicate column names and search values.

    'in' filters will be used to search for any entries in kwargs
    that name DataSource columns and have values of type tuple. This
    allows column values to match multiple values (OR)

    'like' filters will be used for any entries in kwargs that
    name DataSource columns and have string values containing '%'.
    This allows column values to match simple wildcards.

    Any other entries in kwargs will be searched for using filter_by()
    """
    query = model_query(m.DataSource, context)
    query, kwargs = in_filter(query, m.DataSource, kwargs)
    query, kwargs = like_filter(query, m.DataSource, kwargs)

    # Use normal filter_by for remaining keys
    try:
        return query.filter_by(**kwargs).count()
    except Exception as e:
        if kwargs:
            # If kwargs is non-empty then we assume this
            # is a bad field reference. The user asked for something
            # that doesn't exist, so the count is 0
            return 0
        raise e


def data_source_get_all(context, **kwargs):
    query = model_query(m.DataSource, context)
    try:
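
As a rough usage sketch (the context and values here are illustrative and
mirror the unit tests added below): tuple values become 'in' filters, strings
containing '%' become 'like' filters, and any remaining kwargs fall through
to filter_by().

    # count data sources whose name is one of three values and whose
    # url starts with swift://
    cnt = data_source_count(ctx,
                            name=('input1', 'input2', 'input3'),
                            url='swift://%')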

@ -21,8 +21,10 @@ import six
from sahara import conductor as c
from sahara import context
from sahara.openstack.common import uuidutils
from sahara.plugins import base as plugin_base
from sahara.service.edp.binary_retrievers import dispatch
from sahara.swift import swift_helper as sw
from sahara.utils import edp
from sahara.utils import remote
@ -39,6 +41,12 @@ CONF.register_opts(opts)
conductor = c.API
# Prefix used to mark data_source name references in arg lists
DATA_SOURCE_PREFIX = "datasource://"
DATA_SOURCE_SUBST_NAME = "edp.substitute_data_source_for_name"
DATA_SOURCE_SUBST_UUID = "edp.substitute_data_source_for_uuid"
def get_plugin(cluster):
    return plugin_base.PLUGINS.get_plugin(cluster.plugin_name)
@ -94,3 +102,175 @@ def _append_slash_if_needed(path):
    if path[-1] != '/':
        path += '/'
    return path


def may_contain_data_source_refs(job_configs):

    def _check_data_source_ref_option(option):
        truth = job_configs and (
            job_configs.get('configs', {}).get(option))
        # Config values specified in the UI may be
        # passed as strings
        return truth in (True, 'True')

    return (
        _check_data_source_ref_option(DATA_SOURCE_SUBST_NAME),
        _check_data_source_ref_option(DATA_SOURCE_SUBST_UUID))
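
A minimal sketch of the flag handling; note that values set through the UI
may arrive as the string 'True' rather than a boolean:

    job_configs = {'configs': {DATA_SOURCE_SUBST_NAME: 'True',
                               DATA_SOURCE_SUBST_UUID: False}}
    by_name, by_uuid = may_contain_data_source_refs(job_configs)
    # by_name is True, by_uuid is False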
def _data_source_ref_search(job_configs, func, prune=lambda x: x):
    """Return a list of unique values in job_configs filtered by func().

    Loop over the 'args', 'configs' and 'params' elements in
    job_configs and return a list of all values for which
    func(value) is True.

    Optionally provide a 'prune' function that is applied
    to values before they are added to the return value.
    """
    args = set([prune(arg) for arg in job_configs.get(
        'args', []) if func(arg)])

    configs = set([prune(val) for val in six.itervalues(
        job_configs.get('configs', {})) if func(val)])

    params = set([prune(val) for val in six.itervalues(
        job_configs.get('params', {})) if func(val)])

    return list(args | configs | params)
def find_possible_data_source_refs_by_name(job_configs):
    """Find string values in job_configs starting with 'datasource://'.

    Loop over the 'args', 'configs', and 'params' elements of
    job_configs to find all values beginning with the prefix
    'datasource://'. Return a list of unique values with the prefix
    removed.

    Note that for 'configs' and 'params', which are dictionaries, only
    the values are considered and the keys are not relevant.
    """
    def startswith(arg):
        return isinstance(
            arg,
            six.string_types) and arg.startswith(DATA_SOURCE_PREFIX)
    return _data_source_ref_search(job_configs,
                                   startswith,
                                   prune=lambda x: x[len(DATA_SOURCE_PREFIX):])


def find_possible_data_source_refs_by_uuid(job_configs):
    """Find string values in job_configs which are uuids.

    Return a list of unique values in the 'args', 'configs', and 'params'
    elements of job_configs which have the form of a uuid.

    Note that for 'configs' and 'params', which are dictionaries, only
    the values are considered and the keys are not relevant.
    """
    return _data_source_ref_search(job_configs, uuidutils.is_uuid_like)
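
For example (hypothetical values, using the standard library uuid module):

    import uuid

    uid = str(uuid.uuid4())
    job_configs = {'args': ['datasource://raw_logs', uid],
                   'configs': {'config': 'value'},
                   'params': {}}
    find_possible_data_source_refs_by_name(job_configs)  # ['raw_logs']
    find_possible_data_source_refs_by_uuid(job_configs)  # [uid]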
def _add_credentials_for_data_sources(ds_list, configs):

    username = password = None
    for src in ds_list:
        if src.type == "swift" and hasattr(src, "credentials"):
            if "user" in src.credentials:
                username = src.credentials['user']
            if "password" in src.credentials:
                password = src.credentials['password']
            break

    # Don't overwrite if there is already a value here
    if configs.get(sw.HADOOP_SWIFT_USERNAME, None) is None and (
            username is not None):
        configs[sw.HADOOP_SWIFT_USERNAME] = username
    if configs.get(sw.HADOOP_SWIFT_PASSWORD, None) is None and (
            password is not None):
        configs[sw.HADOOP_SWIFT_PASSWORD] = password
def resolve_data_source_references(job_configs):
    """Resolve possible data_source references in job_configs.

    Look for any string values in the 'args', 'configs', and 'params'
    elements of job_configs which start with 'datasource://' or have
    the form of a uuid.

    For values beginning with 'datasource://', strip off the prefix
    and search for a DataSource object with a name that matches the
    value.

    For values having the form of a uuid, search for a DataSource object
    with an id that matches the value.

    If a DataSource object is found for the value, replace the value
    with the URL from the DataSource object. If any DataSource objects
    are found which reference swift paths and contain credentials, set
    credential configuration values in job_configs (use the first set
    of swift credentials found).

    If no values are resolved, return an empty list and a reference
    to job_configs.

    If any values are resolved, return a list of the referenced
    data_source objects and a copy of job_configs with all of the
    references replaced with URLs.
    """
    by_name, by_uuid = may_contain_data_source_refs(job_configs)
    if not (by_name or by_uuid):
        return [], job_configs

    ctx = context.ctx()
    ds_seen = {}
    new_configs = {}

    def _resolve(value):
        kwargs = {}
        if by_name and isinstance(
                value,
                six.string_types) and value.startswith(DATA_SOURCE_PREFIX):
            value = value[len(DATA_SOURCE_PREFIX):]
            kwargs['name'] = value

        elif by_uuid and uuidutils.is_uuid_like(value):
            kwargs['id'] = value

        if kwargs:
            # Name and id are both unique constraints so if there
            # is more than 1 something is really wrong
            ds = conductor.data_source_get_all(ctx, **kwargs)
            if len(ds) == 1:
                ds = ds[0]
                ds_seen[ds.id] = ds
                return ds.url
        return value

    # Loop over configs/params/args and look up each value as a data_source.
    # If we find it, replace the value. In all cases, we've produced a
    # copy which is not a FrozenClass type and can be updated.
    new_configs['configs'] = {
        k: _resolve(v) for k, v in six.iteritems(
            job_configs.get('configs', {}))}

    new_configs['params'] = {
        k: _resolve(v) for k, v in six.iteritems(
            job_configs.get('params', {}))}

    new_configs['args'] = [_resolve(a) for a in job_configs.get('args', [])]

    # If we didn't resolve anything we might as well return the original
    ds_seen = ds_seen.values()
    if not ds_seen:
        return [], job_configs

    # If there are no proxy_configs and the user has not already set configs
    # for swift credentials, set those configs based on data_sources we found
    if not job_configs.get('proxy_configs'):
        _add_credentials_for_data_sources(ds_seen, new_configs['configs'])
    else:
        # we'll need to copy these, too, so job_configs is complete
        new_configs['proxy_configs'] = {
            k: v for k, v in six.iteritems(job_configs.get('proxy_configs'))}

    return ds_seen, new_configs
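
A sketch of the intended calling pattern from a job engine (substitution
flags are assumed to be set in job_execution.job_configs; see the Oozie and
Spark engine changes below):

    ds_list, configs = resolve_data_source_references(
        job_execution.job_configs)
    # ds_list is [] and configs is the original job_configs if nothing
    # was resolved; otherwise configs is a plain dict copy with the
    # references replaced by URLs and ds_list holds the DataSource
    # objects that were substituted.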

@ -81,9 +81,20 @@ class OozieJobEngine(base_engine.JobEngine):
job = conductor.job_get(ctx, job_execution.job_id)
input_source, output_source = job_utils.get_data_sources(job_execution,
job)
proxy_configs = job_execution.job_configs.get('proxy_configs')
for data_source in [input_source, output_source]:
# Updated_job_configs will be a copy of job_execution.job_configs with
# any name or uuid references to data_sources resolved to paths
# assuming substitution is enabled.
# If substitution is not enabled then updated_job_configs will
# just be a reference to job_execution.job_configs to avoid a copy.
# Additional_sources will be a list of any data_sources found.
additional_sources, updated_job_configs = (
job_utils.resolve_data_source_references(job_execution.job_configs)
)
proxy_configs = updated_job_configs.get('proxy_configs')
for data_source in [input_source, output_source] + additional_sources:
if data_source and data_source.type == 'hdfs':
h.configure_cluster_for_hdfs(self.cluster, data_source)
break
@ -99,7 +110,8 @@ class OozieJobEngine(base_engine.JobEngine):
proxy_configs)
wf_xml = workflow_factory.get_workflow_xml(
job, self.cluster, job_execution, input_source, output_source,
job, self.cluster, updated_job_configs,
input_source, output_source,
hdfs_user)
path_to_workflow = self._upload_workflow_file(oozie_server, wf_dir,

@ -140,14 +140,14 @@ class PigFactory(BaseFactory):
def get_script_name(self, job):
return conductor.job_main_name(context.ctx(), job)
def get_workflow_xml(self, cluster, execution, input_data, output_data,
def get_workflow_xml(self, cluster, job_configs, input_data, output_data,
hdfs_user):
proxy_configs = execution.job_configs.get('proxy_configs')
proxy_configs = job_configs.get('proxy_configs')
job_dict = {'configs': self.get_configs(input_data, output_data,
proxy_configs),
'params': self.get_params(input_data, output_data),
'args': []}
self.update_job_dict(job_dict, execution.job_configs)
self.update_job_dict(job_dict, job_configs)
creator = pig_workflow.PigWorkflowCreator()
creator.build_workflow_xml(self.name,
configuration=job_dict['configs'],
@ -165,13 +165,13 @@ class HiveFactory(BaseFactory):
def get_script_name(self, job):
return conductor.job_main_name(context.ctx(), job)
def get_workflow_xml(self, cluster, execution, input_data, output_data,
def get_workflow_xml(self, cluster, job_configs, input_data, output_data,
hdfs_user):
proxy_configs = execution.job_configs.get('proxy_configs')
proxy_configs = job_configs.get('proxy_configs')
job_dict = {'configs': self.get_configs(input_data, output_data,
proxy_configs),
'params': self.get_params(input_data, output_data)}
self.update_job_dict(job_dict, execution.job_configs)
self.update_job_dict(job_dict, job_configs)
creator = hive_workflow.HiveWorkflowCreator()
creator.build_workflow_xml(self.name,
@ -196,12 +196,12 @@ class MapReduceFactory(BaseFactory):
return dict((k[len(prefix):], v) for (k, v) in six.iteritems(
job_dict['edp_configs']) if k.startswith(prefix))
def get_workflow_xml(self, cluster, execution, input_data, output_data,
def get_workflow_xml(self, cluster, job_configs, input_data, output_data,
hdfs_user):
proxy_configs = execution.job_configs.get('proxy_configs')
proxy_configs = job_configs.get('proxy_configs')
job_dict = {'configs': self.get_configs(input_data, output_data,
proxy_configs)}
self.update_job_dict(job_dict, execution.job_configs)
self.update_job_dict(job_dict, job_configs)
creator = mapreduce_workflow.MapReduceWorkFlowCreator()
creator.build_workflow_xml(configuration=job_dict['configs'],
streaming=self._get_streaming(job_dict))
@ -230,11 +230,11 @@ class JavaFactory(BaseFactory):
return configs
def get_workflow_xml(self, cluster, execution, *args, **kwargs):
proxy_configs = execution.job_configs.get('proxy_configs')
def get_workflow_xml(self, cluster, job_configs, *args, **kwargs):
proxy_configs = job_configs.get('proxy_configs')
job_dict = {'configs': self.get_configs(proxy_configs=proxy_configs),
'args': []}
self.update_job_dict(job_dict, execution.job_configs)
self.update_job_dict(job_dict, job_configs)
main_class, java_opts = self._get_java_configs(job_dict)
creator = java_workflow.JavaWorkflowCreator()
@ -264,9 +264,9 @@ def _get_creator(job):
return type_map[job.type]()
def get_workflow_xml(job, cluster, execution, *args, **kwargs):
def get_workflow_xml(job, cluster, job_configs, *args, **kwargs):
return _get_creator(job).get_workflow_xml(
cluster, execution, *args, **kwargs)
cluster, job_configs, *args, **kwargs)
def get_possible_job_config(job_type):

@ -114,7 +114,11 @@ class SparkJobEngine(base_engine.JobEngine):
ctx = context.ctx()
job = conductor.job_get(ctx, job_execution.job_id)
proxy_configs = job_execution.job_configs.get('proxy_configs')
additional_sources, updated_job_configs = (
job_utils.resolve_data_source_references(job_execution.job_configs)
)
proxy_configs = updated_job_configs.get('proxy_configs')
# We'll always run the driver program on the master
master = plugin_utils.get_instance(self.cluster, "master")
@ -150,11 +154,11 @@ class SparkJobEngine(base_engine.JobEngine):
self.cluster),
"bin/spark-submit")
job_class = job_execution.job_configs.configs["edp.java.main_class"]
job_class = updated_job_configs['configs']["edp.java.main_class"]
# TODO(tmckay): we need to clean up wf_dirs on long running clusters
# TODO(tmckay): probably allow for general options to spark-submit
args = " ".join(job_execution.job_configs.get('args', []))
args = " ".join(updated_job_configs.get('args', []))
# The redirects of stdout and stderr will preserve output in the wf_dir
cmd = "%s %s --class %s %s --master spark://%s:%s %s" % (

@ -21,7 +21,7 @@ import uuid
import fixtures
import six
from sahara.swift import swift_helper as sw
from sahara.service.edp import job_utils
from sahara.tests.integration.tests import base
from sahara.utils import edp
@ -204,17 +204,13 @@ class EDPTest(base.ITestCase):
)
)
def _add_swift_configs(self, configs):
def _enable_substitution(self, configs):
if "configs" not in configs:
configs["configs"] = {}
if sw.HADOOP_SWIFT_USERNAME not in configs["configs"]:
configs["configs"][
sw.HADOOP_SWIFT_USERNAME] = self.common_config.OS_USERNAME
if sw.HADOOP_SWIFT_PASSWORD not in configs["configs"]:
configs["configs"][
sw.HADOOP_SWIFT_PASSWORD] = self.common_config.OS_PASSWORD
configs['configs'][job_utils.DATA_SOURCE_SUBST_NAME] = True
configs['configs'][job_utils.DATA_SOURCE_SUBST_UUID] = True
@base.skip_test('SKIP_EDP_TEST', 'Test for EDP was skipped.')
def check_edp_hive(self):
@ -279,17 +275,14 @@ class EDPTest(base.ITestCase):
output_type = "swift"
output_url = 'swift://%s.sahara/output' % container_name
# Java jobs don't use data sources. Input/output paths must
# be passed as args with corresponding username/password configs
if not edp.compare_job_type(job_type,
edp.JOB_TYPE_JAVA,
edp.JOB_TYPE_SPARK):
input_id = self._create_data_source(
'input-%s' % str(uuid.uuid4())[:8], 'swift',
swift_input_url)
output_id = self._create_data_source(
'output-%s' % str(uuid.uuid4())[:8], output_type,
output_url)
input_name = 'input-%s' % str(uuid.uuid4())[:8]
input_id = self._create_data_source(input_name,
'swift', swift_input_url)
output_name = 'output-%s' % str(uuid.uuid4())[:8]
output_id = self._create_data_source(output_name,
output_type,
output_url)
if job_data_list:
if swift_binaries:
@ -329,11 +322,13 @@ class EDPTest(base.ITestCase):
# if the caller has requested it...
if edp.compare_job_type(
job_type, edp.JOB_TYPE_JAVA) and pass_input_output_args:
self._add_swift_configs(configs)
self._enable_substitution(configs)
input_arg = job_utils.DATA_SOURCE_PREFIX + input_name
output_arg = output_id
if "args" in configs:
configs["args"].extend([swift_input_url, output_url])
configs["args"].extend([input_arg, output_arg])
else:
configs["args"] = [swift_input_url, output_url]
configs["args"] = [input_arg, output_arg]
job_execution = self.sahara.job_executions.create(
job_id, self.cluster_id, input_id, output_id,

@ -158,6 +158,57 @@ class DataSourceTest(test_base.ConductorManagerTestCase):
        lst = self.api.data_source_get_all(ctx, **{'badfield': 'somevalue'})
        self.assertEqual(len(lst), 0)

    def test_data_source_count_in(self):
        ctx = context.ctx()
        ctx.tenant_id = SAMPLE_DATA_SOURCE['tenant_id']
        src = copy.copy(SAMPLE_DATA_SOURCE)
        self.api.data_source_create(ctx, src)

        cnt = self.api.data_source_count(ctx, name='ngt_test')
        self.assertEqual(cnt, 1)

        cnt = self.api.data_source_count(ctx, name=('ngt_test',
                                                    'test2', 'test3'))
        self.assertEqual(cnt, 1)

        cnt = self.api.data_source_count(ctx, name=('test1',
                                                    'test2', 'test3'))
        self.assertEqual(cnt, 0)

        lst = self.api.data_source_get_all(ctx, name='ngt_test')
        myid = lst[0]['id']
        cnt = self.api.data_source_count(ctx,
                                         name=('ngt_test', 'test2', 'test3'),
                                         id=myid)
        self.assertEqual(cnt, 1)

        cnt = self.api.data_source_count(ctx,
                                         name=('ngt_test', 'test2', 'test3'),
                                         id=(myid, '2'))
        self.assertEqual(cnt, 1)

    def test_data_source_count_like(self):
        ctx = context.ctx()
        ctx.tenant_id = SAMPLE_DATA_SOURCE['tenant_id']
        src = copy.copy(SAMPLE_DATA_SOURCE)
        self.api.data_source_create(ctx, src)

        cnt = self.api.data_source_count(ctx, name='ngt_test')
        self.assertEqual(cnt, 1)

        cnt = self.api.data_source_count(ctx, name='ngt%')
        self.assertEqual(cnt, 1)

        cnt = self.api.data_source_count(ctx,
                                         name=('ngt_test',),
                                         url='localhost%')
        self.assertEqual(cnt, 1)

        cnt = self.api.data_source_count(ctx,
                                         name=('ngt_test',),
                                         url='localhost')
        self.assertEqual(cnt, 0)


class JobExecutionTest(test_base.ConductorManagerTestCase):
    def test_crud_operation_create_list_delete_update(self):

@ -68,7 +68,7 @@ def create_cluster(plugin_name='vanilla', hadoop_version='1.2.1'):
return cluster
def create_data_source(url):
def create_data_source(url, name=None, id=None):
data_source = mock.Mock()
data_source.url = url
if url.startswith("swift"):
@ -77,6 +77,10 @@ def create_data_source(url):
'password': 'admin1'}
elif url.startswith("hdfs"):
data_source.type = "hdfs"
if name is not None:
data_source.name = name
if id is not None:
data_source.id = id
return data_source

@ -323,11 +323,6 @@ class TestSpark(base.SaharaTestCase):
upload_job_files, get_config_value, get_remote,
job_exec_get):
def fix_get(field, default=None):
if field == "args":
return ["input_arg", "output_arg"]
return default
eng = se.SparkJobEngine("cluster")
job = mock.Mock()
@ -335,9 +330,9 @@ class TestSpark(base.SaharaTestCase):
job_get.return_value = job
job_exec = mock.Mock()
job_exec.job_configs.configs = {"edp.java.main_class":
"org.me.myclass"}
job_exec.job_configs.get = fix_get
job_exec.job_configs = {'configs': {"edp.java.main_class":
"org.me.myclass"},
'args': ['input_arg', 'output_arg']}
master = mock.Mock()
get_instance.return_value = master

@ -120,8 +120,8 @@ class TestJobManager(base.SaharaWithDbTestCase):
output_data = u.create_data_source('swift://ex/o')
res = workflow_factory.get_workflow_xml(
job, u.create_cluster(), job_exec, input_data, output_data,
'hadoop')
job, u.create_cluster(), job_exec.job_configs,
input_data, output_data, 'hadoop')
self.assertIn("""
<param>INPUT=swift://ex.sahara/i</param>
@ -147,8 +147,8 @@ class TestJobManager(base.SaharaWithDbTestCase):
job, job_exec = u.create_job_exec(edp.JOB_TYPE_PIG, proxy=True)
res = workflow_factory.get_workflow_xml(
job, u.create_cluster(), job_exec, input_data, output_data,
'hadoop')
job, u.create_cluster(), job_exec.job_configs,
input_data, output_data, 'hadoop')
self.assertIn("""
<configuration>
@ -181,8 +181,8 @@ class TestJobManager(base.SaharaWithDbTestCase):
output_data = u.create_data_source('hdfs://user/hadoop/out')
res = workflow_factory.get_workflow_xml(
job, u.create_cluster(), job_exec, input_data, output_data,
'hadoop')
job, u.create_cluster(), job_exec.job_configs,
input_data, output_data, 'hadoop')
self.assertIn("""
<configuration>
@ -200,8 +200,8 @@ class TestJobManager(base.SaharaWithDbTestCase):
output_data = u.create_data_source('swift://ex/o')
res = workflow_factory.get_workflow_xml(
job, u.create_cluster(), job_exec, input_data, output_data,
'hadoop')
job, u.create_cluster(), job_exec.job_configs,
input_data, output_data, 'hadoop')
self.assertIn("""
<configuration>
@ -221,8 +221,8 @@ class TestJobManager(base.SaharaWithDbTestCase):
output_data = u.create_data_source('hdfs://user/hadoop/out')
res = workflow_factory.get_workflow_xml(
job, u.create_cluster(), job_exec, input_data, output_data,
'hadoop')
job, u.create_cluster(), job_exec.job_configs,
input_data, output_data, 'hadoop')
self.assertIn("""
<configuration>
@ -246,8 +246,8 @@ class TestJobManager(base.SaharaWithDbTestCase):
output_data = u.create_data_source('swift://ex/o')
res = workflow_factory.get_workflow_xml(
job, u.create_cluster(), job_exec, input_data, output_data,
'hadoop')
job, u.create_cluster(), job_exec.job_configs,
input_data, output_data, 'hadoop')
if streaming:
self.assertIn("""
@ -288,8 +288,8 @@ class TestJobManager(base.SaharaWithDbTestCase):
job, job_exec = u.create_job_exec(job_type, proxy=True)
res = workflow_factory.get_workflow_xml(
job, u.create_cluster(), job_exec, input_data, output_data,
'hadoop')
job, u.create_cluster(), job_exec.job_configs,
input_data, output_data, 'hadoop')
self.assertIn("""
<property>
@ -331,7 +331,7 @@ class TestJobManager(base.SaharaWithDbTestCase):
job, job_exec = u.create_job_exec(edp.JOB_TYPE_JAVA, configs)
res = workflow_factory.get_workflow_xml(
job, u.create_cluster(), job_exec)
job, u.create_cluster(), job_exec.job_configs)
self.assertIn("""
<configuration>
@ -361,7 +361,7 @@ class TestJobManager(base.SaharaWithDbTestCase):
job, job_exec = u.create_job_exec(edp.JOB_TYPE_JAVA, configs,
proxy=True)
res = workflow_factory.get_workflow_xml(job, u.create_cluster(),
job_exec)
job_exec.job_configs)
self.assertIn("""
<configuration>
@ -397,8 +397,8 @@ class TestJobManager(base.SaharaWithDbTestCase):
output_data = u.create_data_source('swift://ex/o')
res = workflow_factory.get_workflow_xml(
job, u.create_cluster(), job_exec, input_data, output_data,
'hadoop')
job, u.create_cluster(), job_exec.job_configs,
input_data, output_data, 'hadoop')
doc = xml.parseString(res)
hive = doc.getElementsByTagName('hive')[0]
@ -425,8 +425,8 @@ class TestJobManager(base.SaharaWithDbTestCase):
job, job_exec = u.create_job_exec(edp.JOB_TYPE_HIVE, proxy=True)
res = workflow_factory.get_workflow_xml(
job, u.create_cluster(), job_exec, input_data, output_data,
'hadoop')
job, u.create_cluster(), job_exec.job_configs,
input_data, output_data, 'hadoop')
doc = xml.parseString(res)
hive = doc.getElementsByTagName('hive')[0]

@ -0,0 +1,249 @@
# Copyright (c) 2013 Mirantis Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import uuid
import mock
import six
import testtools
from sahara import conductor as cond
from sahara.service.edp import job_utils
from sahara.tests.unit.service.edp import edp_test_utils as u
conductor = cond.API
class JobUtilsTestCase(testtools.TestCase):
def setUp(self):
super(JobUtilsTestCase, self).setUp()
def test_args_may_contain_data_sources(self):
job_configs = None
# No configs, default false
by_name, by_uuid = job_utils.may_contain_data_source_refs(job_configs)
self.assertFalse(by_name | by_uuid)
# Empty configs, default false
job_configs = {'configs': {}}
by_name, by_uuid = job_utils.may_contain_data_source_refs(job_configs)
self.assertFalse(by_name | by_uuid)
job_configs['configs'] = {job_utils.DATA_SOURCE_SUBST_NAME: True,
job_utils.DATA_SOURCE_SUBST_UUID: True}
by_name, by_uuid = job_utils.may_contain_data_source_refs(job_configs)
self.assertTrue(by_name & by_uuid)
job_configs['configs'][job_utils.DATA_SOURCE_SUBST_NAME] = False
by_name, by_uuid = job_utils.may_contain_data_source_refs(job_configs)
self.assertFalse(by_name)
self.assertTrue(by_uuid)
job_configs['configs'][job_utils.DATA_SOURCE_SUBST_UUID] = False
by_name, by_uuid = job_utils.may_contain_data_source_refs(job_configs)
self.assertFalse(by_name | by_uuid)
job_configs['configs'] = {job_utils.DATA_SOURCE_SUBST_NAME: 'True',
job_utils.DATA_SOURCE_SUBST_UUID: 'Fish'}
by_name, by_uuid = job_utils.may_contain_data_source_refs(job_configs)
self.assertTrue(by_name)
self.assertFalse(by_uuid)
def test_find_possible_data_source_refs_by_name(self):
id = six.text_type(uuid.uuid4())
job_configs = {}
self.assertEqual([],
job_utils.find_possible_data_source_refs_by_name(
job_configs))
name_ref = job_utils.DATA_SOURCE_PREFIX+'name'
name_ref2 = name_ref+'2'
job_configs = {'args': ['first', id],
'configs': {'config': 'value'},
'params': {'param': 'value'}}
self.assertEqual([],
job_utils.find_possible_data_source_refs_by_name(
job_configs))
job_configs = {'args': [name_ref, id],
'configs': {'config': 'value'},
'params': {'param': 'value'}}
self.assertEqual(
['name'],
job_utils.find_possible_data_source_refs_by_name(job_configs))
job_configs = {'args': ['first', id],
'configs': {'config': name_ref},
'params': {'param': 'value'}}
self.assertEqual(
['name'],
job_utils.find_possible_data_source_refs_by_name(job_configs))
job_configs = {'args': ['first', id],
'configs': {'config': 'value'},
'params': {'param': name_ref}}
self.assertEqual(
['name'],
job_utils.find_possible_data_source_refs_by_name(job_configs))
job_configs = {'args': [name_ref, name_ref2, id],
'configs': {'config': name_ref},
'params': {'param': name_ref}}
self.assertItemsEqual(
['name', 'name2'],
job_utils.find_possible_data_source_refs_by_name(job_configs))
def test_find_possible_data_source_refs_by_uuid(self):
job_configs = {}
name_ref = job_utils.DATA_SOURCE_PREFIX+'name'
self.assertEqual([],
job_utils.find_possible_data_source_refs_by_uuid(
job_configs))
id = six.text_type(uuid.uuid4())
job_configs = {'args': ['first', name_ref],
'configs': {'config': 'value'},
'params': {'param': 'value'}}
self.assertEqual([],
job_utils.find_possible_data_source_refs_by_uuid(
job_configs))
job_configs = {'args': [id, name_ref],
'configs': {'config': 'value'},
'params': {'param': 'value'}}
self.assertEqual(
[id],
job_utils.find_possible_data_source_refs_by_uuid(job_configs))
job_configs = {'args': ['first', name_ref],
'configs': {'config': id},
'params': {'param': 'value'}}
self.assertEqual(
[id],
job_utils.find_possible_data_source_refs_by_uuid(job_configs))
job_configs = {'args': ['first', name_ref],
'configs': {'config': 'value'},
'params': {'param': id}}
self.assertEqual(
[id],
job_utils.find_possible_data_source_refs_by_uuid(job_configs))
id2 = six.text_type(uuid.uuid4())
job_configs = {'args': [id, id2, name_ref],
'configs': {'config': id},
'params': {'param': id}}
self.assertItemsEqual([id, id2],
job_utils.find_possible_data_source_refs_by_uuid(
job_configs))
@mock.patch('sahara.context.ctx')
@mock.patch('sahara.conductor.API.data_source_get_all')
def test_resolve_data_source_refs(self, data_source_get_all, ctx):
ctx.return_value = 'dummy'
name_ref = job_utils.DATA_SOURCE_PREFIX+'input'
input = u.create_data_source("swift://container/input",
name="input",
id=six.text_type(uuid.uuid4()))
output = u.create_data_source("swift://container/output",
name="output",
id=six.text_type(uuid.uuid4()))
by_name = {'input': input,
'output': output}
by_id = {input.id: input,
output.id: output}
# Pretend to be the database
def _get_all(ctx, **kwargs):
name = kwargs.get('name')
if name in by_name:
name_list = [by_name[name]]
else:
name_list = []
id = kwargs.get('id')
if id in by_id:
id_list = [by_id[id]]
else:
id_list = []
return list(set(name_list + id_list))
data_source_get_all.side_effect = _get_all
job_configs = {
'configs': {
job_utils.DATA_SOURCE_SUBST_NAME: True,
job_utils.DATA_SOURCE_SUBST_UUID: True},
'args': [name_ref, output.id, input.id]}
ds, nc = job_utils.resolve_data_source_references(job_configs)
self.assertEqual(len(ds), 2)
self.assertEqual(nc['args'], [input.url, output.url, input.url])
# Swift configs should be filled in since they were blank
self.assertEqual(nc['configs']['fs.swift.service.sahara.username'],
input.credentials['user'])
self.assertEqual(nc['configs']['fs.swift.service.sahara.password'],
input.credentials['password'])
job_configs['configs'] = {'fs.swift.service.sahara.username': 'sam',
'fs.swift.service.sahara.password': 'gamgee',
job_utils.DATA_SOURCE_SUBST_NAME: False,
job_utils.DATA_SOURCE_SUBST_UUID: True}
ds, nc = job_utils.resolve_data_source_references(job_configs)
self.assertEqual(len(ds), 2)
self.assertEqual(nc['args'], [name_ref, output.url, input.url])
# Swift configs should not be overwritten
self.assertEqual(nc['configs'], job_configs['configs'])
job_configs['configs'] = {job_utils.DATA_SOURCE_SUBST_NAME: True,
job_utils.DATA_SOURCE_SUBST_UUID: False}
job_configs['proxy_configs'] = {'proxy_username': 'john',
'proxy_password': 'smith',
'proxy_trust_id': 'trustme'}
ds, nc = job_utils.resolve_data_source_references(job_configs)
self.assertEqual(len(ds), 1)
self.assertEqual(nc['args'], [input.url, output.id, input.id])
# Swift configs should be empty and proxy configs should be preserved
self.assertEqual(nc['configs'], job_configs['configs'])
self.assertEqual(nc['proxy_configs'], job_configs['proxy_configs'])
# Substitution not enabled
job_configs['configs'] = {job_utils.DATA_SOURCE_SUBST_NAME: False,
job_utils.DATA_SOURCE_SUBST_UUID: False}
ds, nc = job_utils.resolve_data_source_references(job_configs)
self.assertEqual(len(ds), 0)
self.assertEqual(nc['args'], job_configs['args'])
self.assertEqual(nc['configs'], job_configs['configs'])
# Substitution enabled but no values to modify
job_configs['configs'] = {job_utils.DATA_SOURCE_SUBST_NAME: True,
job_utils.DATA_SOURCE_SUBST_UUID: True}
job_configs['args'] = ['val1', 'val2', 'val3']
ds, nc = job_utils.resolve_data_source_references(job_configs)
self.assertEqual(len(ds), 0)
self.assertEqual(nc['args'], job_configs['args'])
self.assertEqual(nc['configs'], job_configs['configs'])

@ -13,8 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import mock
import uuid
import mock
import six
from sahara.service.edp import job_utils
from sahara.tests.unit import base
from sahara.utils import proxy as p
@ -25,7 +29,14 @@ class TestProxyUtils(base.SaharaWithDbTestCase):
@mock.patch('sahara.conductor.API.job_get')
@mock.patch('sahara.conductor.API.data_source_get')
def test_job_execution_requires_proxy_user(self, data_source, job):
@mock.patch('sahara.conductor.API.data_source_count')
@mock.patch('sahara.context.ctx')
def test_job_execution_requires_proxy_user(self,
ctx,
data_source_count,
data_source,
job):
self.override_config('use_domain_for_proxy_users', True)
job_execution = mock.Mock(input_id=1,
output_id=2,
@ -44,8 +55,47 @@ class TestProxyUtils(base.SaharaWithDbTestCase):
libs=[mock.Mock(url='swift://container/object')])
self.assertTrue(p.job_execution_requires_proxy_user(job_execution))
job_execution.job_configs['args'] = ['swift://container/object']
job_execution.job_configs = {'args': ['swift://container/object']}
job.return_value = mock.Mock(
mains=[],
libs=[])
self.assertTrue(p.job_execution_requires_proxy_user(job_execution))
job_execution.job_configs = {
'configs': {'key': 'swift://container/object'}}
self.assertTrue(p.job_execution_requires_proxy_user(job_execution))
job_execution.job_configs = {
'params': {'key': 'swift://container/object'}}
self.assertTrue(p.job_execution_requires_proxy_user(job_execution))
data_source_count.return_value = 0
job_execution.job_configs = {
'configs': {job_utils.DATA_SOURCE_SUBST_NAME: True}}
job.return_value = mock.Mock(
mains=[],
libs=[])
self.assertFalse(p.job_execution_requires_proxy_user(job_execution))
ctx.return_value = 'dummy'
data_source_count.return_value = 1
job_execution.job_configs = {
'configs': {job_utils.DATA_SOURCE_SUBST_NAME: True},
'args': [job_utils.DATA_SOURCE_PREFIX+'somevalue']}
self.assertTrue(p.job_execution_requires_proxy_user(job_execution))
data_source_count.assert_called_with('dummy',
name=('somevalue',),
url='swift://%')
data_source_count.reset_mock()
data_source_count.return_value = 1
myid = six.text_type(uuid.uuid4())
job_execution.job_configs = {
'configs': {job_utils.DATA_SOURCE_SUBST_UUID: True},
'args': [myid]}
job.return_value = mock.Mock(
mains=[],
libs=[])
self.assertTrue(p.job_execution_requires_proxy_user(job_execution))
data_source_count.assert_called_with('dummy',
id=(myid,),
url='swift://%')

@ -23,6 +23,7 @@ from sahara import context
from sahara import exceptions as ex
from sahara.i18n import _
from sahara.openstack.common import log as logging
from sahara.service.edp import job_utils
from sahara.service import trusts as t
from sahara.swift import utils as su
from sahara.utils.openstack import keystone as k
@ -180,27 +181,63 @@ def domain_for_proxy():
def job_execution_requires_proxy_user(job_execution):
'''Returns True if the job execution requires a proxy user.'''
def _check_values(values):
return any(value.startswith(
su.SWIFT_INTERNAL_PREFIX) for value in values if (
isinstance(value, six.string_types)))
if CONF.use_domain_for_proxy_users is False:
return False
input_ds = conductor.data_source_get(context.ctx(),
job_execution.input_id)
if input_ds and input_ds.url.startswith(su.SWIFT_INTERNAL_PREFIX):
return True
output_ds = conductor.data_source_get(context.ctx(),
job_execution.output_id)
if output_ds and output_ds.url.startswith(su.SWIFT_INTERNAL_PREFIX):
return True
if job_execution.job_configs.get('args'):
for arg in job_execution.job_configs['args']:
if arg.startswith(su.SWIFT_INTERNAL_PREFIX):
return True
paths = [conductor.data_source_get(context.ctx(), job_execution.output_id),
conductor.data_source_get(context.ctx(), job_execution.input_id)]
if _check_values(ds.url for ds in paths if ds):
return True
if _check_values(six.itervalues(
job_execution.job_configs.get('configs', {}))):
return True
if _check_values(six.itervalues(
job_execution.job_configs.get('params', {}))):
return True
if _check_values(job_execution.job_configs.get('args', [])):
return True
job = conductor.job_get(context.ctx(), job_execution.job_id)
for main in job.mains:
if main.url.startswith(su.SWIFT_INTERNAL_PREFIX):
if _check_values(main.url for main in job.mains):
return True
if _check_values(lib.url for lib in job.libs):
return True
# We did the simple checks, now if data_source referencing is
# enabled and we have values that could be a name or uuid,
# query for data_sources that match and contain a swift path
by_name, by_uuid = job_utils.may_contain_data_source_refs(
job_execution.job_configs)
if by_name:
names = tuple(job_utils.find_possible_data_source_refs_by_name(
job_execution.job_configs))
# do a query here for name in names and path starts with swift-prefix
if names and conductor.data_source_count(
context.ctx(),
name=names,
url=su.SWIFT_INTERNAL_PREFIX+'%') > 0:
return True
for lib in job.libs:
if lib.url.startswith(su.SWIFT_INTERNAL_PREFIX):
if by_uuid:
uuids = tuple(job_utils.find_possible_data_source_refs_by_uuid(
job_execution.job_configs))
# do a query here for id in uuids and path starts with swift-prefix
if uuids and conductor.data_source_count(
context.ctx(),
id=uuids,
url=su.SWIFT_INTERNAL_PREFIX+'%') > 0:
return True
return False
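
For instance (values are hypothetical, mirroring the unit test changes
above), a job_execution with

    job_configs = {'configs': {job_utils.DATA_SOURCE_SUBST_NAME: True},
                   'args': [job_utils.DATA_SOURCE_PREFIX + 'somevalue']}

results in roughly the query

    conductor.data_source_count(ctx, name=('somevalue',), url='swift://%')

and a proxy user is required when that count is non-zero.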