Commit 3dfd0ca1 authored by Wout De Nolf

option to keep only end-node results when executing a graph

parent 74f8e3c0
Pipeline #59898 passed with stages in 41 seconds
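In short, `TaskGraph.execute` gains a `results_of_all_nodes` flag (default `False`): intermediate task results are evicted as soon as their last consumer has read them, so only the end-node tasks are returned. A minimal usage sketch; `my_graph_description` is a placeholder and not a name defined in this repository:

from ewokscore import load_graph

ewoksgraph = load_graph(my_graph_description)  # placeholder: any acyclic graph definition

# Default after this change: intermediate results are evicted along the way
end_tasks = ewoksgraph.execute(results_of_all_nodes=False)
assert end_tasks.keys() == ewoksgraph.end_nodes()

# Opt in to keeping every node's task, as the updated tests do
all_tasks = ewoksgraph.execute(results_of_all_nodes=True)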
 import os
 import enum
 import json
+from collections import Counter
 from collections.abc import Mapping
-from typing import Optional, Set
+from typing import Dict, Iterable, Optional, Set
 import networkx
 from . import inittask
@@ -286,8 +287,9 @@ class TaskGraph:
     def instantiate_task_static(
         self,
         node_id: NodeIdType,
-        tasks: Optional[dict] = None,
+        tasks: Optional[Dict[NodeIdType, Task]] = None,
         varinfo: Optional[dict] = None,
+        evict_result_counter: Optional[Dict[NodeIdType, int]] = None,
     ) -> Task:
         """Instantiate destination task while no access to the dynamic inputs.
         Side effect: `tasks` will contain all predecessors.
@@ -296,20 +298,34 @@ class TaskGraph:
             raise RuntimeError(f"{self} is cyclic")
         if tasks is None:
             tasks = dict()
+        if evict_result_counter is None:
+            evict_result_counter = dict()
         # Input from previous tasks (instantiate them if needed)
         dynamic_inputs = dict()
-        for inputnode in self.predecessors(node_id):
-            inputtask = tasks.get(inputnode, None)
-            if inputtask is None:
-                inputtask = self.instantiate_task_static(
-                    inputnode, tasks=tasks, varinfo=varinfo
+        for source_node_id in self.predecessors(node_id):
+            source_task = tasks.get(source_node_id, None)
+            if source_task is None:
+                source_task = self.instantiate_task_static(
+                    source_node_id,
+                    tasks=tasks,
+                    varinfo=varinfo,
+                    evict_result_counter=evict_result_counter,
                 )
-            link_attrs = self.graph[inputnode][node_id]
+            link_attrs = self.graph[source_node_id][node_id]
             inittask.add_dynamic_inputs(
-                dynamic_inputs, link_attrs, inputtask.output_variables
+                dynamic_inputs, link_attrs, source_task.output_variables
             )
-        task = self.instantiate_task(node_id, varinfo=varinfo, inputs=dynamic_inputs)
-        tasks[node_id] = task
-        return task
+            # Evict intermediate results
+            if evict_result_counter:
+                evict_result_counter[source_node_id] -= 1
+                if evict_result_counter[source_node_id] == 0:
+                    tasks.pop(source_node_id)
+        # Instantiate the requested task
+        target_task = self.instantiate_task(
+            node_id, varinfo=varinfo, inputs=dynamic_inputs
+        )
+        tasks[node_id] = target_task
+        return target_task

     def successors(self, node_id: NodeIdType, **include_filter):
         yield from self._iter_downstream_nodes(
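The eviction above is plain reference counting: every producer starts with a count equal to its number of successors, the count drops by one each time a successor consumes its output, and the task is removed from `tasks` when the count reaches zero. A standalone sketch of that pattern with made-up names (not ewokscore code):

from collections import Counter

# Diamond graph: "a" feeds "b" and "c", which both feed "d"
edges = [("a", "b"), ("a", "c"), ("b", "d"), ("c", "d")]
remaining = Counter(src for src, _ in edges)  # {"a": 2, "b": 1, "c": 1}

results = {"a": 1}  # pretend "a" already produced a result


def consume(producer: str) -> int:
    """Read a producer's result and evict it once its last consumer is served."""
    value = results[producer]
    remaining[producer] -= 1
    if remaining[producer] == 0:
        results.pop(producer)  # same idea as tasks.pop(source_node_id) above
    return value


consume("a")  # "b" reads: one consumer left, result kept
consume("a")  # "c" reads: last consumer, result evicted
assert "a" not in results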
@@ -613,20 +629,40 @@ class TaskGraph:
             if linkattrs.get("on_error") and linkattrs.get("conditions"):
                 raise ValueError(err_msg.format("on_error", "conditions"))

-    def topological_sort(self):
+    def topological_sort(self) -> Iterable[NodeIdType]:
         """Sort node names for sequential instantiation+execution of DAGs"""
         if self.is_cyclic:
             raise RuntimeError("Sorting nodes is not possible for cyclic graphs")
         yield from networkx.topological_sort(self.graph)

-    def execute(self, varinfo: Optional[dict] = None, raise_on_error: bool = True):
+    def successor_counter(self) -> Dict[NodeIdType, int]:
+        nsuccessor = Counter()
+        for edge in self.graph.edges:
+            nsuccessor[edge[0]] += 1
+        return nsuccessor
+
+    def execute(
+        self,
+        varinfo: Optional[dict] = None,
+        raise_on_error: Optional[bool] = True,
+        results_of_all_nodes: Optional[bool] = False,
+    ) -> Dict[NodeIdType, Task]:
         """Sequential execution of DAGs"""
         if self.is_cyclic:
             raise RuntimeError("Cannot execute cyclic graphs")
         if self.has_conditional_links:
             raise RuntimeError("Cannot execute graphs with conditional links")
+        if results_of_all_nodes:
+            evict_result_counter = None
+        else:
+            evict_result_counter = self.successor_counter()
         tasks = dict()
-        for node in self.topological_sort():
-            task = self.instantiate_task_static(node, tasks=tasks, varinfo=varinfo)
+        for node_id in self.topological_sort():
+            task = self.instantiate_task_static(
+                node_id,
+                tasks=tasks,
+                varinfo=varinfo,
+                evict_result_counter=evict_result_counter,
+            )
             task.execute(raise_on_error=raise_on_error)
         return tasks
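`successor_counter` counts outgoing edges per node, i.e. the out-degree of every non-sink node. End nodes never appear in the counter, so their tasks are never evicted, which is why `results_of_all_nodes=False` leaves exactly the end-node results in the returned dict. A small sketch of what it computes, assuming `TaskGraph.graph` is a `networkx.DiGraph` keyed by node ids (as the loop above suggests):

from collections import Counter

import networkx

g = networkx.DiGraph([("node1", "node2"), ("node1", "node3"), ("node2", "node3")])

nsuccessor = Counter()
for edge in g.edges:  # same loop body as successor_counter
    nsuccessor[edge[0]] += 1

assert nsuccessor == Counter({"node1": 2, "node2": 1})  # "node3" is an end node: absent
assert all(nsuccessor[n] == g.out_degree(n) for n in nsuccessor)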
 from collections.abc import Mapping
-from typing import Union
+from typing import Optional, Union
 from .hashing import UniversalHashable
 from .variable import VariableContainer
@@ -261,7 +261,9 @@ class Task(Registered, UniversalHashable, register=False):
                     "The following inputs could not be loaded: " + str(lst)
                 )

-    def execute(self, force_rerun=False, raise_on_error=True):
+    def execute(
+        self, force_rerun: Optional[bool] = False, raise_on_error: Optional[bool] = True
+    ):
         try:
             if force_rerun:
                 # Rerun a task which is already done
......
@@ -3,6 +3,7 @@ import itertools
 from .examples.graphs import graph_names
 from .examples.graphs import get_graph
 from .utils import assert_taskgraph_result
+from .utils import assert_workflow_result
 from ewokscore import load_graph
@@ -20,11 +21,17 @@ def test_execute_graph(graph_name, persist, tmpdir):
         with pytest.raises(RuntimeError):
             ewoksgraph.execute(varinfo=varinfo)
     else:
-        tasks = ewoksgraph.execute(varinfo=varinfo)
+        tasks = ewoksgraph.execute(varinfo=varinfo, results_of_all_nodes=True)
         assert_taskgraph_result(ewoksgraph, expected, tasks=tasks)
         if persist:
             assert_taskgraph_result(ewoksgraph, expected, varinfo=varinfo)
+        end_tasks = ewoksgraph.execute(varinfo=varinfo, results_of_all_nodes=False)
+        end_nodes = ewoksgraph.end_nodes()
+        assert end_tasks.keys() == end_nodes
+        expected = {k: v for k, v in expected.items() if k in end_nodes}
+        assert_workflow_result(end_tasks, expected, varinfo=varinfo)


 def test_graph_cyclic():
     g, _ = get_graph("empty")
......
@@ -57,7 +57,7 @@ def test_sub_graph():
     }
     ewoksgraph = load_graph(graph)
-    tasks = ewoksgraph.execute()
+    tasks = ewoksgraph.execute(results_of_all_nodes=True)
     expected = {
         "node1": {"return_value": 1},
         ("node2", ("subnode1", "subsubnode1")): {"return_value": 2},
......
@@ -222,7 +222,7 @@ def graph(tmpdir, subgraph):
 def test_load_from_json(tmpdir, graph):
     taskgraph = load_graph(graph, root_dir=str(tmpdir))
-    tasks = taskgraph.execute()
+    tasks = taskgraph.execute(results_of_all_nodes=True)
     assert len(tasks) == 13
......
+from typing import Any, Dict, Optional
 import networkx
 from pprint import pprint
 import matplotlib.pyplot as plt
 from ewokscore import load_graph
+from ewokscore.graph import TaskGraph
+from ewokscore.node import NodeIdType
+from ewokscore.task import Task
 from ewokscore.variable import value_from_transfer


-def assert_taskgraph_result(taskgraph, expected, varinfo=None, tasks=None):
+def assert_taskgraph_result(
+    taskgraph: TaskGraph,
+    expected: Dict[NodeIdType, Any],
+    varinfo: Optional[dict] = None,
+    tasks: Optional[Dict[NodeIdType, Task]] = None,
+):
     taskgraph = load_graph(taskgraph)
     assert not taskgraph.is_cyclic, "Can only check DAG results"
@@ -20,38 +29,47 @@ def assert_taskgraph_result(taskgraph, expected, varinfo=None, tasks=None):
         assert_task_result(task, node, expected)


-def assert_task_result(task, node, expected):
-    expected_value = expected.get(node)
+def assert_task_result(task: Task, node_id: NodeIdType, expected: dict):
+    expected_value = expected.get(node_id)
     if expected_value is None:
-        assert not task.done, node
+        assert not task.done, node_id
     else:
-        assert task.done, node
+        assert task.done, node_id
         try:
-            assert task.output_values == expected_value, node
+            assert task.output_values == expected_value, node_id
         except AssertionError:
             raise
         except Exception as e:
-            raise RuntimeError(f"{node} does not have a result") from e
+            raise RuntimeError(f"{node_id} does not have a result") from e


-def assert_workflow_result(results, expected, varinfo=None):
+def assert_workflow_result(
+    results: Dict[NodeIdType, Any],
+    expected: Dict[NodeIdType, Any],
+    varinfo: Optional[dict] = None,
+):
     for node_id, expected_result in expected.items():
         if expected_result is None:
             assert node_id not in results
             continue
         result = results[node_id]
+        if isinstance(result, Task):
+            assert result.done, node_id
+            result = result.output_values
         for output_name, expected_value in expected_result.items():
             value = result[output_name]
             assert_result(value, expected_value, varinfo=varinfo)


-def assert_workflow_merged_result(result, expected, varinfo=None):
+def assert_workflow_merged_result(
+    result: dict, expected: Dict[NodeIdType, Any], varinfo: Optional[dict] = None
+):
     for output_name, expected_value in expected.items():
         value = result[output_name]
         assert_result(value, expected_value, varinfo=varinfo)


-def assert_result(value, expected_value, varinfo=None):
+def assert_result(value, expected_value, varinfo: Optional[dict] = None):
     value = value_from_transfer(value, varinfo=varinfo)
     assert value == expected_value
......
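With the new `isinstance(result, Task)` branch, `assert_workflow_result` accepts either the `Task` objects returned by `TaskGraph.execute` or plain mappings of output name to value, and an expected value of `None` asserts that the node is absent from the results (evicted or never executed). A hedged usage sketch; the import path and the assumption that plain values pass through `value_from_transfer` unchanged are mine, not guaranteed by this diff:

from ewokscore.tests.utils import assert_workflow_result  # path assumed from this diff's layout

results = {"node1": {"return_value": 1}}  # plain mapping instead of a Task object
expected = {
    "node1": {"return_value": 1},  # outputs that must match
    "node2": None,  # None: the node must not appear in the results
}
assert_workflow_result(results, expected)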