Commit 4dfea638 authored by Wout De Nolf's avatar Wout De Nolf
Browse files

Merge branch 'explicit_variable_passing' into 'main'

pass variables explicitly to allow non-persistent execution

See merge request !6
parents c3fb7bc5 7471c347
Pipeline #51228 passed with stages
in 58 seconds
......@@ -2,58 +2,114 @@
https://docs.dask.org/en/latest/scheduler-overview.html
"""
import json
import logging
from dask.distributed import Client
from dask.threaded import get as multithreading_scheduler
from dask.multiprocessing import get as multiprocessing_scheduler
from dask import get as sequential_scheduler
from ewokscore import load_graph
from ewokscore.inittask import instantiate_task
from ewokscore.inittask import add_dynamic_inputs
from ewokscore.graph import ewoks_jsonload_hook
def execute_task(node_name, *inputs):
    """Execute one workflow node (legacy, pre-merge signature).

    :param str node_name: graph node name with a 3-character suffix
                          appended at dask-graph build time
    :param inputs: first positional input is the shared info dict with
                   the serialized graph and ``varinfo``
    :returns: the shared info dict (passed through to dependent tasks)
    """
    info = inputs[0]
    # Strip the suffix that was appended when the dask graph was built
    stripped_name = node_name[:-3]
    ewoksgraph = load_graph(info["ewoksgraph"])
    task = ewoksgraph.instantiate_task_static(stripped_name, varinfo=info["varinfo"])
    task.execute()
    return info
logger = logging.getLogger(__name__)
def execute_task(execinfo, *inputs):
    """Execute a single graph node inside a dask task.

    :param str execinfo: JSON-serialized execution info (node name and
                         attributes, link attributes, varinfo, logging flag)
    :param inputs: transfer data returned by the source (predecessor) tasks,
                   in the same order as ``execinfo["link_attrs"]``
    :returns: the task's output transfer data (consumed by successor tasks)
    :raises Exception: re-raises whatever ``task.execute`` raised, after
                       optionally logging it
    """
    execinfo = json.loads(execinfo, object_pairs_hook=ewoks_jsonload_hook)

    # Merge the outputs of all source tasks into this task's dynamic inputs
    dynamic_inputs = {}
    for source_results, link_attrs in zip(inputs, execinfo["link_attrs"]):
        add_dynamic_inputs(dynamic_inputs, link_attrs, source_results)

    task = instantiate_task(
        execinfo["node_attrs"],
        node_name=execinfo["node_name"],
        inputs=dynamic_inputs,
        varinfo=execinfo["varinfo"],
    )
    try:
        task.execute()
    except Exception as e:
        if execinfo["enable_logging"]:
            # Lazy %-args: the message is only formatted when the record
            # is actually emitted (cheaper than eager str.format)
            logger.error(
                "\nEXECUTE %s %s\n INPUTS: %s\n ERROR: %s",
                execinfo["node_name"],
                repr(task),
                task.input_values,
                e,
            )
        raise
    if execinfo["enable_logging"]:
        logger.info(
            "\nEXECUTE %s %s\n INPUTS: %s\n OUTPUTS: %s",
            execinfo["node_name"],
            repr(task),
            task.input_values,
            task.output_values,
        )
    return task.output_transfer_data
def convert_graph(ewoksgraph, varinfo, enable_logging=False):
    """Convert an ewoks task graph to a dask graph.

    :param ewoksgraph: a loaded ewoks graph (acyclic, no conditional links)
    :param varinfo: variable persistence info, forwarded to every task
    :param bool enable_logging: log each task execution (inputs/outputs/errors)
    :returns dict: dask graph mapping node name -> (execute_task, execinfo, *sources)
    """
    daskgraph = dict()
    for target, node_attrs in ewoksgraph.graph.nodes.items():
        sources = tuple(source for source in ewoksgraph.predecessors(target))
        link_attrs = tuple(ewoksgraph.graph[source][target] for source in sources)
        execinfo = {
            "node_name": target,
            "node_attrs": node_attrs,
            "link_attrs": link_attrs,
            "varinfo": varinfo,
            "enable_logging": enable_logging,
        }
        # Note: the execinfo is serialized to prevent dask
        # from interpreting node names as task results
        daskgraph[target] = (execute_task, json.dumps(execinfo)) + sources
    return daskgraph
def execute_graph(
    graph,
    representation=None,
    varinfo=None,
    scheduler=None,
    log_task_execution=False,
    results_of_all_nodes=False,
    **load_options,
):
    """Execute an ewoks graph with a dask scheduler.

    :param graph: graph source (object, file, dict, ...) accepted by `load_graph`
    :param representation: graph representation hint for `load_graph`
    :param varinfo: variable persistence info, forwarded to every task
    :param scheduler: None (sequential), "multithreading", "multiprocessing",
                      a dict of `dask.distributed.Client` arguments, or an
                      object with a dask-compatible ``get`` method
    :param bool log_task_execution: log each task's inputs/outputs/errors
    :param bool results_of_all_nodes: return results for every node instead of
                                      only the graph's result nodes
    :param load_options: extra keyword arguments for `load_graph`
    :returns dict: node name -> task transfer data
    :raises RuntimeError: for cyclic graphs or graphs with conditional links
    :raises ValueError: for an unknown scheduler name
    """
    ewoksgraph = load_graph(source=graph, representation=representation, **load_options)
    if ewoksgraph.is_cyclic:
        raise RuntimeError("Dask can only execute DAGs")
    if ewoksgraph.has_conditional_links:
        raise RuntimeError("Dask cannot handle conditional links")
    daskgraph = convert_graph(ewoksgraph, varinfo, enable_logging=log_task_execution)

    if results_of_all_nodes:
        nodes = list(ewoksgraph.graph.nodes)
    else:
        nodes = list(ewoksgraph.result_nodes())

    if scheduler is None:
        results = sequential_scheduler(daskgraph, nodes)
    elif isinstance(scheduler, str):
        if scheduler == "multiprocessing":
            results = multiprocessing_scheduler(daskgraph, nodes)
        elif scheduler == "multithreading":
            results = multithreading_scheduler(daskgraph, nodes)
        else:
            raise ValueError("Unknown scheduler")
    elif isinstance(scheduler, dict):
        # Bind the client to its own name instead of shadowing the
        # `scheduler` parameter, so the parameter stays readable in tracebacks
        with Client(**scheduler) as client:
            results = client.get(daskgraph, nodes)
    else:
        results = scheduler.get(daskgraph, nodes)
    return dict(zip(nodes, results))
......@@ -6,6 +6,7 @@ from ewoksdask import execute_graph
from ewokscore.tests.examples.graphs import graph_names
from ewokscore.tests.examples.graphs import get_graph
from ewokscore.tests.utils import assert_taskgraph_result
from ewokscore.tests.utils import assert_workflow_result
from ewokscore import load_graph
logging.getLogger("dask").setLevel(logging.DEBUG)
......@@ -15,16 +16,25 @@ logging.getLogger("ewoksdask").addHandler(logging.StreamHandler(sys.stdout))
@pytest.mark.parametrize(
    "graph_name,scheduler,persist",
    itertools.product(
        graph_names(), (None, "multithreading", "multiprocessing"), (True, False)
    ),
)
def test_examples(graph_name, tmpdir, scheduler, persist):
    """Run every example graph with each scheduler, with and without
    variable persistence, and compare against the expected results."""
    graph, expected = get_graph(graph_name)
    ewoksgraph = load_graph(graph)
    if persist:
        varinfo = {"root_uri": str(tmpdir)}
    else:
        # Non-persistent execution: results are only passed in memory
        varinfo = None
    if ewoksgraph.is_cyclic or ewoksgraph.has_conditional_links:
        # Dask cannot execute these graphs: the executor must refuse them
        with pytest.raises(RuntimeError):
            execute_graph(graph, scheduler=scheduler, varinfo=varinfo)
    else:
        result = execute_graph(
            graph, scheduler=scheduler, varinfo=varinfo, results_of_all_nodes=True
        )
        if persist:
            # Persisted task outputs can be verified on disk
            assert_taskgraph_result(ewoksgraph, expected, varinfo=varinfo)
        assert_workflow_result(result, expected, varinfo=varinfo)
......@@ -23,6 +23,7 @@ install_requires =
dask
distributed
dask-jobqueue
bokeh
[options.extras_require]
dev =
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment