diff --git a/ewoksdask/bindings.py b/ewoksdask/bindings.py index d8fc8d72f88718cdd654740cff855a101a3e1751..87929016710ef682ab583a3ec146bdaaf0df25e3 100644 --- a/ewoksdask/bindings.py +++ b/ewoksdask/bindings.py @@ -3,7 +3,7 @@ https://docs.dask.org/en/latest/scheduler-overview.html """ import json -from typing import Optional +from typing import List, Optional from dask.distributed import Client from dask.threaded import get as multithreading_scheduler from dask.multiprocessing import get as multiprocessing_scheduler @@ -55,7 +55,9 @@ def convert_graph(ewoksgraph, **execute_options): def execute_graph( graph, scheduler=None, + inputs: Optional[List[dict]] = None, results_of_all_nodes: Optional[bool] = False, + outputs: Optional[List[dict]] = None, load_options: Optional[dict] = None, **execute_options ): @@ -66,6 +68,8 @@ def execute_graph( raise RuntimeError("Dask can only execute DAGs") if ewoksgraph.has_conditional_links: raise RuntimeError("Dask cannot handle conditional links") + if inputs: + ewoksgraph.update_default_inputs(inputs) daskgraph = convert_graph(ewoksgraph, **execute_options) if results_of_all_nodes: diff --git a/ewoksdask/tests/test_examples.py b/ewoksdask/tests/test_examples.py index 44d4467e9698cdd8208ca8ef0181ffd7a116a136..1d272a9bb20ac6db4fa6e55c65e409fedab54518 100644 --- a/ewoksdask/tests/test_examples.py +++ b/ewoksdask/tests/test_examples.py @@ -1,40 +1,50 @@ -import sys -import logging import pytest -import itertools + from ewoksdask import execute_graph -from ewokscore.tests.examples.graphs import graph_names -from ewokscore.tests.examples.graphs import get_graph -from ewokscore.tests.utils import assert_taskgraph_result -from ewokscore.tests.utils import assert_workflow_result from ewokscore import load_graph -logging.getLogger("dask").setLevel(logging.DEBUG) -logging.getLogger("dask").addHandler(logging.StreamHandler(sys.stdout)) -logging.getLogger("ewoksdask").setLevel(logging.DEBUG) -logging.getLogger("ewoksdask").addHandler(logging.StreamHandler(sys.stdout)) +from ewokscore.tests.examples.graphs import graph_names +from ewokscore.tests.examples.graphs import get_graph +from ewokscore.tests.utils.results import assert_execute_graph_all_tasks +from ewokscore.tests.utils.results import assert_execute_graph_tasks +from ewokscore.tests.utils.results import filter_expected_results -@pytest.mark.parametrize( - "graph_name,scheduler,persist", - itertools.product( - graph_names(), (None, "multithreading", "multiprocessing"), (True, False) - ), -) -def test_examples(graph_name, tmpdir, scheduler, persist): +@pytest.mark.parametrize("graph_name", graph_names()) +@pytest.mark.parametrize("scheduler", (None, "multithreading", "multiprocessing")) +@pytest.mark.parametrize("scheme", (None, "json")) +def test_examples(graph_name, tmpdir, scheduler, scheme): graph, expected = get_graph(graph_name) ewoksgraph = load_graph(graph) - if persist: - varinfo = {"root_uri": str(tmpdir)} + if scheme: + varinfo = {"root_uri": str(tmpdir), "scheme": scheme} else: varinfo = None if ewoksgraph.is_cyclic or ewoksgraph.has_conditional_links: with pytest.raises(RuntimeError): execute_graph(graph, scheduler=scheduler, varinfo=varinfo) + return + + result = execute_graph( + graph, scheduler=scheduler, varinfo=varinfo, results_of_all_nodes=True + ) + assert_all_results(ewoksgraph, result, expected, varinfo) + result = execute_graph( + graph, scheduler=scheduler, varinfo=varinfo, results_of_all_nodes=False + ) + assert_end_results(ewoksgraph, result, expected, varinfo) + + +def assert_all_results(ewoksgraph, result, expected, varinfo): + if varinfo: + scheme = varinfo.get("scheme") else: - result = execute_graph( - graph, scheduler=scheduler, varinfo=varinfo, results_of_all_nodes=True - ) - if persist: - assert_taskgraph_result(ewoksgraph, expected, varinfo=varinfo) - assert_workflow_result(result, expected, varinfo=varinfo) + scheme = None + if scheme: + assert_execute_graph_all_tasks(ewoksgraph, expected, varinfo=varinfo) + assert_execute_graph_tasks(result, expected, varinfo=varinfo) + + +def assert_end_results(ewoksgraph, result, expected, varinfo): + expected = filter_expected_results(ewoksgraph, expected, end_only=True) + assert_execute_graph_tasks(result, expected, varinfo=varinfo)