Commit 3dfd0ca1 authored by Wout De Nolf

option to keep only end-node results when executing a graph

parent 74f8e3c0
Pipeline #59898 passed with stages in 41 seconds
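For context, a minimal usage sketch of the new option (not part of the commit; the example graph name and the absolute test-module path are assumptions, while load_graph, execute(results_of_all_nodes=...) and end_nodes() appear in the diff below):

    from ewokscore import load_graph
    from ewokscore.tests.examples.graphs import get_graph  # assumed module path

    graph, _expected = get_graph("acyclic1")  # assumed name of an acyclic example graph
    ewoksgraph = load_graph(graph)

    # Keep the Task object of every node (the pre-commit behaviour):
    all_tasks = ewoksgraph.execute(results_of_all_nodes=True)

    # Keep only the Task objects of the end nodes; intermediate results are
    # evicted as soon as all their successors have consumed them:
    end_tasks = ewoksgraph.execute(results_of_all_nodes=False)
    assert end_tasks.keys() == ewoksgraph.end_nodes()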
 import os
 import enum
 import json
+from collections import Counter
 from collections.abc import Mapping
-from typing import Optional, Set
+from typing import Dict, Iterable, Optional, Set
 import networkx
 from . import inittask
@@ -286,8 +287,9 @@ class TaskGraph:
     def instantiate_task_static(
         self,
         node_id: NodeIdType,
-        tasks: Optional[dict] = None,
+        tasks: Optional[Dict[NodeIdType, Task]] = None,
         varinfo: Optional[dict] = None,
+        evict_result_counter: Optional[Dict[NodeIdType, int]] = None,
     ) -> Task:
         """Instantiate destination task while no access to the dynamic inputs.
         Side effect: `tasks` will contain all predecessors.
@@ -296,20 +298,34 @@ class TaskGraph:
             raise RuntimeError(f"{self} is cyclic")
         if tasks is None:
             tasks = dict()
+        if evict_result_counter is None:
+            evict_result_counter = dict()
+        # Input from previous tasks (instantiate them if needed)
         dynamic_inputs = dict()
-        for inputnode in self.predecessors(node_id):
-            inputtask = tasks.get(inputnode, None)
-            if inputtask is None:
-                inputtask = self.instantiate_task_static(
-                    inputnode, tasks=tasks, varinfo=varinfo
+        for source_node_id in self.predecessors(node_id):
+            source_task = tasks.get(source_node_id, None)
+            if source_task is None:
+                source_task = self.instantiate_task_static(
+                    source_node_id,
+                    tasks=tasks,
+                    varinfo=varinfo,
+                    evict_result_counter=evict_result_counter,
                 )
-            link_attrs = self.graph[inputnode][node_id]
+            link_attrs = self.graph[source_node_id][node_id]
             inittask.add_dynamic_inputs(
-                dynamic_inputs, link_attrs, inputtask.output_variables
+                dynamic_inputs, link_attrs, source_task.output_variables
             )
-        task = self.instantiate_task(node_id, varinfo=varinfo, inputs=dynamic_inputs)
-        tasks[node_id] = task
-        return task
+            # Evict intermediate results
+            if evict_result_counter:
+                evict_result_counter[source_node_id] -= 1
+                if evict_result_counter[source_node_id] == 0:
+                    tasks.pop(source_node_id)
+        # Instantiate the requested task
+        target_task = self.instantiate_task(
+            node_id, varinfo=varinfo, inputs=dynamic_inputs
+        )
+        tasks[node_id] = target_task
+        return target_task

     def successors(self, node_id: NodeIdType, **include_filter):
         yield from self._iter_downstream_nodes(
@@ -613,20 +629,40 @@ class TaskGraph:
         if linkattrs.get("on_error") and linkattrs.get("conditions"):
             raise ValueError(err_msg.format("on_error", "conditions"))

-    def topological_sort(self):
+    def topological_sort(self) -> Iterable[NodeIdType]:
         """Sort node names for sequential instantiation+execution of DAGs"""
         if self.is_cyclic:
             raise RuntimeError("Sorting nodes is not possible for cyclic graphs")
         yield from networkx.topological_sort(self.graph)

-    def execute(self, varinfo: Optional[dict] = None, raise_on_error: bool = True):
+    def successor_counter(self) -> Dict[NodeIdType, int]:
+        nsuccessor = Counter()
+        for edge in self.graph.edges:
+            nsuccessor[edge[0]] += 1
+        return nsuccessor
+
+    def execute(
+        self,
+        varinfo: Optional[dict] = None,
+        raise_on_error: Optional[bool] = True,
+        results_of_all_nodes: Optional[bool] = False,
+    ) -> Dict[NodeIdType, Task]:
         """Sequential execution of DAGs"""
         if self.is_cyclic:
             raise RuntimeError("Cannot execute cyclic graphs")
         if self.has_conditional_links:
             raise RuntimeError("Cannot execute graphs with conditional links")
+        if results_of_all_nodes:
+            evict_result_counter = None
+        else:
+            evict_result_counter = self.successor_counter()
         tasks = dict()
-        for node in self.topological_sort():
-            task = self.instantiate_task_static(node, tasks=tasks, varinfo=varinfo)
+        for node_id in self.topological_sort():
+            task = self.instantiate_task_static(
+                node_id,
+                tasks=tasks,
+                varinfo=varinfo,
+                evict_result_counter=evict_result_counter,
+            )
             task.execute(raise_on_error=raise_on_error)
         return tasks
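The eviction bookkeeping above is independent of ewokscore: each node starts out with its successor count, every consumer decrements it, and the stored result is dropped when the count reaches zero, so only end-node results (successor count zero from the start) survive. A self-contained sketch of that idea with generic names (an illustration only, not the project's API):

    from collections import Counter
    from typing import Callable, Dict, Hashable, List, Tuple

    Edge = Tuple[Hashable, Hashable]

    def execute_with_eviction(
        edges: List[Edge],
        order: List[Hashable],
        run: Callable[[Hashable, dict], object],
    ) -> Dict[Hashable, object]:
        """Run nodes in topological `order`; keep only results of nodes that
        have no remaining successors to serve (i.e. the end nodes)."""
        remaining = Counter(src for src, _ in edges)  # successors left per node
        results: Dict[Hashable, object] = {}
        for node in order:
            inputs = {}
            for src, dst in edges:
                if dst != node:
                    continue
                inputs[src] = results[src]
                remaining[src] -= 1
                if remaining[src] == 0:
                    results.pop(src)  # evict an intermediate result
            results[node] = run(node, inputs)
        return results

    # Chain a -> b -> c: only the end node "c" keeps its result.
    out = execute_with_eviction(
        [("a", "b"), ("b", "c")],
        ["a", "b", "c"],
        lambda node, inputs: f"result of {node}",
    )
    assert set(out) == {"c"}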
...
 from collections.abc import Mapping
-from typing import Union
+from typing import Optional, Union
 from .hashing import UniversalHashable
 from .variable import VariableContainer
@@ -261,7 +261,9 @@ class Task(Registered, UniversalHashable, register=False):
                 "The following inputs could not be loaded: " + str(lst)
             )

-    def execute(self, force_rerun=False, raise_on_error=True):
+    def execute(
+        self, force_rerun: Optional[bool] = False, raise_on_error: Optional[bool] = True
+    ):
         try:
             if force_rerun:
                 # Rerun a task which is already done
...
@@ -3,6 +3,7 @@ import itertools
 from .examples.graphs import graph_names
 from .examples.graphs import get_graph
 from .utils import assert_taskgraph_result
+from .utils import assert_workflow_result
 from ewokscore import load_graph
@@ -20,11 +21,17 @@ def test_execute_graph(graph_name, persist, tmpdir):
         with pytest.raises(RuntimeError):
             ewoksgraph.execute(varinfo=varinfo)
     else:
-        tasks = ewoksgraph.execute(varinfo=varinfo)
+        tasks = ewoksgraph.execute(varinfo=varinfo, results_of_all_nodes=True)
         assert_taskgraph_result(ewoksgraph, expected, tasks=tasks)
         if persist:
             assert_taskgraph_result(ewoksgraph, expected, varinfo=varinfo)
+
+        end_tasks = ewoksgraph.execute(varinfo=varinfo, results_of_all_nodes=False)
+        end_nodes = ewoksgraph.end_nodes()
+        assert end_tasks.keys() == end_nodes
+        expected = {k: v for k, v in expected.items() if k in end_nodes}
+        assert_workflow_result(end_tasks, expected, varinfo=varinfo)


 def test_graph_cyclic():
     g, _ = get_graph("empty")
...
@@ -57,7 +57,7 @@ def test_sub_graph():
     }
     ewoksgraph = load_graph(graph)
-    tasks = ewoksgraph.execute()
+    tasks = ewoksgraph.execute(results_of_all_nodes=True)
     expected = {
         "node1": {"return_value": 1},
         ("node2", ("subnode1", "subsubnode1")): {"return_value": 2},
...
@@ -222,7 +222,7 @@ def graph(tmpdir, subgraph):
 def test_load_from_json(tmpdir, graph):
     taskgraph = load_graph(graph, root_dir=str(tmpdir))
-    tasks = taskgraph.execute()
+    tasks = taskgraph.execute(results_of_all_nodes=True)
     assert len(tasks) == 13
...
+from typing import Any, Dict, Optional
 import networkx
 from pprint import pprint
 import matplotlib.pyplot as plt
 from ewokscore import load_graph
+from ewokscore.graph import TaskGraph
+from ewokscore.node import NodeIdType
+from ewokscore.task import Task
 from ewokscore.variable import value_from_transfer


-def assert_taskgraph_result(taskgraph, expected, varinfo=None, tasks=None):
+def assert_taskgraph_result(
+    taskgraph: TaskGraph,
+    expected: Dict[NodeIdType, Any],
+    varinfo: Optional[dict] = None,
+    tasks: Optional[Dict[NodeIdType, Task]] = None,
+):
     taskgraph = load_graph(taskgraph)
     assert not taskgraph.is_cyclic, "Can only check DAG results"
@@ -20,38 +29,47 @@ def assert_taskgraph_result(taskgraph, expected, varinfo=None, tasks=None):
         assert_task_result(task, node, expected)


-def assert_task_result(task, node, expected):
-    expected_value = expected.get(node)
+def assert_task_result(task: Task, node_id: NodeIdType, expected: dict):
+    expected_value = expected.get(node_id)
     if expected_value is None:
-        assert not task.done, node
+        assert not task.done, node_id
     else:
-        assert task.done, node
+        assert task.done, node_id
         try:
-            assert task.output_values == expected_value, node
+            assert task.output_values == expected_value, node_id
         except AssertionError:
             raise
         except Exception as e:
-            raise RuntimeError(f"{node} does not have a result") from e
+            raise RuntimeError(f"{node_id} does not have a result") from e


-def assert_workflow_result(results, expected, varinfo=None):
+def assert_workflow_result(
+    results: Dict[NodeIdType, Any],
+    expected: Dict[NodeIdType, Any],
+    varinfo: Optional[dict] = None,
+):
     for node_id, expected_result in expected.items():
         if expected_result is None:
             assert node_id not in results
             continue
         result = results[node_id]
+        if isinstance(result, Task):
+            assert result.done, node_id
+            result = result.output_values
         for output_name, expected_value in expected_result.items():
             value = result[output_name]
             assert_result(value, expected_value, varinfo=varinfo)


-def assert_workflow_merged_result(result, expected, varinfo=None):
+def assert_workflow_merged_result(
+    result: dict, expected: Dict[NodeIdType, Any], varinfo: Optional[dict] = None
+):
     for output_name, expected_value in expected.items():
         value = result[output_name]
         assert_result(value, expected_value, varinfo=varinfo)


-def assert_result(value, expected_value, varinfo=None):
+def assert_result(value, expected_value, varinfo: Optional[dict] = None):
     value = value_from_transfer(value, varinfo=varinfo)
     assert value == expected_value
...