diff --git a/ewokscore/__init__.py b/ewokscore/__init__.py index 307f2601d0ff24e518c10021747f53296665a46f..8dc5bb17bd0af1ca613397bf27495de67a8ed567 100644 --- a/ewokscore/__init__.py +++ b/ewokscore/__init__.py @@ -1,6 +1,6 @@ from .task import Task # noqa: F401 from .taskwithprogress import TaskWithProgress # noqa: F401 from .graph import load_graph # noqa: F401 -from .graph import execute_graph # noqa: F401 +from .bindings import execute_graph # noqa: F401 __version__ = "0.0.5-alpha" diff --git a/ewokscore/bindings.py b/ewokscore/bindings.py new file mode 100644 index 0000000000000000000000000000000000000000..fab5e2a905d46d8b01277ba57cea328a0a9fffb6 --- /dev/null +++ b/ewokscore/bindings.py @@ -0,0 +1,14 @@ +from typing import Optional +from .graph import load_graph + + +def execute_graph(graph, load_options: Optional[dict] = None, **execute_options): + """ + :param graph: graph to be executed + :param Optional[dict] load_options: options to provide to the `load_graph` function (and as a consequence to the `TaskGraph.load` as `root_dir`) + :param execute_options: options to provide to the Graph.execute function as `varinfo` or `raise_on_error` + """ + if load_options is None: + load_options = dict() + graph = load_graph(source=graph, **load_options) + return graph.execute(**execute_options) diff --git a/ewokscore/graph.py b/ewokscore/graph.py index 5ae0f30dba2a387cc3ab0727903b0bc85df4a4c4..670d2beba785ea201d25d6ae2e43737618987a13 100644 --- a/ewokscore/graph.py +++ b/ewokscore/graph.py @@ -1,8 +1,9 @@ import os import enum import json +from collections import Counter from collections.abc import Mapping -from typing import Optional, Set +from typing import Dict, Iterable, Optional, Set import networkx from . import inittask @@ -24,18 +25,6 @@ def load_graph(source=None, representation=None, **load_options): return TaskGraph(source=source, representation=representation, **load_options) -def execute_graph(graph, load_options: Optional[dict] = None, **execute_options): - """ - :param graph: graph to be executed - :param Optional[dict] load_options: options to provide to the `load_graph` function (and as a consequence to the `TaskGraph.load` as `root_dir`) - :param execute_options: options to provide to the Graph.execute function as `varinfo` or `raise_on_error` - """ - if load_options is None: - load_options = dict() - graph = load_graph(source=graph, **load_options) - return graph.execute(**execute_options) - - def set_graph_defaults(graph_as_dict): graph_as_dict.setdefault("directed", True) graph_as_dict.setdefault("nodes", list()) @@ -298,8 +287,9 @@ class TaskGraph: def instantiate_task_static( self, node_id: NodeIdType, - tasks: Optional[dict] = None, + tasks: Optional[Dict[Task, int]] = None, varinfo: Optional[dict] = None, + evict_result_counter: Optional[Dict[NodeIdType, int]] = None, ) -> Task: """Instantiate destination task while no access to the dynamic inputs. Side effect: `tasks` will contain all predecessors. @@ -308,20 +298,34 @@ class TaskGraph: raise RuntimeError(f"{self} is cyclic") if tasks is None: tasks = dict() + if evict_result_counter is None: + evict_result_counter = dict() + # Input from previous tasks (instantiate them if needed) dynamic_inputs = dict() - for inputnode in self.predecessors(node_id): - inputtask = tasks.get(inputnode, None) - if inputtask is None: - inputtask = self.instantiate_task_static( - inputnode, tasks=tasks, varinfo=varinfo + for source_node_id in self.predecessors(node_id): + source_task = tasks.get(source_node_id, None) + if source_task is None: + source_task = self.instantiate_task_static( + source_node_id, + tasks=tasks, + varinfo=varinfo, + evict_result_counter=evict_result_counter, ) - link_attrs = self.graph[inputnode][node_id] + link_attrs = self.graph[source_node_id][node_id] inittask.add_dynamic_inputs( - dynamic_inputs, link_attrs, inputtask.output_variables + dynamic_inputs, link_attrs, source_task.output_variables ) - task = self.instantiate_task(node_id, varinfo=varinfo, inputs=dynamic_inputs) - tasks[node_id] = task - return task + # Evict intermediate results + if evict_result_counter: + evict_result_counter[source_node_id] -= 1 + if evict_result_counter[source_node_id] == 0: + tasks.pop(source_node_id) + # Instantiate the requested task + target_task = self.instantiate_task( + node_id, varinfo=varinfo, inputs=dynamic_inputs + ) + tasks[node_id] = target_task + return target_task def successors(self, node_id: NodeIdType, **include_filter): yield from self._iter_downstream_nodes( @@ -625,20 +629,43 @@ class TaskGraph: if linkattrs.get("on_error") and linkattrs.get("conditions"): raise ValueError(err_msg.format("on_error", "conditions")) - def topological_sort(self): + def topological_sort(self) -> Iterable[NodeIdType]: """Sort node names for sequential instantiation+execution of DAGs""" if self.is_cyclic: raise RuntimeError("Sorting nodes is not possible for cyclic graphs") yield from networkx.topological_sort(self.graph) - def execute(self, varinfo: Optional[dict] = None, raise_on_error: bool = True): + def successor_counter(self) -> Dict[NodeIdType, int]: + nsuccessor = Counter() + for edge in self.graph.edges: + nsuccessor[edge[0]] += 1 + return nsuccessor + + def execute( + self, + varinfo: Optional[dict] = None, + raise_on_error: Optional[bool] = True, + results_of_all_nodes: Optional[bool] = False, + ) -> Dict[NodeIdType, Task]: """Sequential execution of DAGs""" if self.is_cyclic: raise RuntimeError("Cannot execute cyclic graphs") if self.has_conditional_links: raise RuntimeError("Cannot execute graphs with conditional links") + if results_of_all_nodes: + evict_result_counter = None + else: + evict_result_counter = self.successor_counter() + cleanup_references = not results_of_all_nodes tasks = dict() - for node in self.topological_sort(): - task = self.instantiate_task_static(node, tasks=tasks, varinfo=varinfo) - task.execute(raise_on_error=raise_on_error) + for node_id in self.topological_sort(): + task = self.instantiate_task_static( + node_id, + tasks=tasks, + varinfo=varinfo, + evict_result_counter=evict_result_counter, + ) + task.execute( + raise_on_error=raise_on_error, cleanup_references=cleanup_references + ) return tasks diff --git a/ewokscore/hashing.py b/ewokscore/hashing.py index 6f83e41d8a41bf8b981a7779059235dffb66c36c..5c87479dc28f554cf82aaf475ef58f6615e4893f 100644 --- a/ewokscore/hashing.py +++ b/ewokscore/hashing.py @@ -237,6 +237,7 @@ class UniversalHashable(HasUhash): return self.__instance_nonce def fix_uhash(self): + """Fix the uhash when it is derived from the uhash data.""" if self.__pre_uhash is not None: return keep, self.__instance_nonce = self.__instance_nonce, None @@ -249,6 +250,15 @@ class UniversalHashable(HasUhash): def undo_fix_uhash(self): self.__pre_uhash = self.__original_pre_uhash + def cleanup_references(self): + """Remove all references to other hashables. + Side effect: fixes the uhash when it depends on another hashable. + """ + if isinstance(self.__pre_uhash, HasUhash): + pre_uhash = self.__pre_uhash.uhash + self.__pre_uhash = pre_uhash + self.__original_pre_uhash = pre_uhash + @property def uhash(self) -> Optional[UniversalHash]: _uhash = self.__pre_uhash diff --git a/ewokscore/task.py b/ewokscore/task.py index 4d2aefc787f1333c8b592cd720c62358c022aae5..c387959cdf8d55c9128401b7f04623f0c1f6ea5b 100644 --- a/ewokscore/task.py +++ b/ewokscore/task.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Union +from typing import Optional, Union from .hashing import UniversalHashable from .variable import VariableContainer @@ -150,6 +150,8 @@ class Task(Registered, UniversalHashable, register=False): @property def input_variables(self): + if self.__inputs is None: + raise RuntimeError("references have been removed") return self.__inputs @property @@ -158,23 +160,23 @@ class Task(Registered, UniversalHashable, register=False): @property def input_uhashes(self): - return self.__inputs.variable_uhashes + return self.input_variables.variable_uhashes @property def input_values(self): - return self.__inputs.variable_values + return self.input_variables.variable_values @property def named_input_values(self): - return self.__inputs.named_variable_values + return self.input_variables.named_variable_values @property def positional_input_values(self): - return self.__inputs.positional_variable_values + return self.input_variables.positional_variable_values @property def npositional_inputs(self): - return self.__inputs.n_positional_variables + return self.input_variables.n_positional_variables @property def output_variables(self): @@ -246,7 +248,7 @@ class Task(Registered, UniversalHashable, register=False): def _iter_missing_input_values(self): for iname in self._INPUT_NAMES: - var = self.__inputs.get(iname) + var = self.input_variables.get(iname) if var is None or not var.has_value: yield iname @@ -265,7 +267,12 @@ class Task(Registered, UniversalHashable, register=False): "The following inputs could not be loaded: " + str(lst) ) - def execute(self, force_rerun=False, raise_on_error=True): + def execute( + self, + force_rerun: Optional[bool] = False, + raise_on_error: Optional[bool] = True, + cleanup_references: Optional[bool] = False, + ): try: if force_rerun: # Rerun a task which is already done @@ -282,6 +289,18 @@ class Task(Registered, UniversalHashable, register=False): raise RuntimeError(f"Task '{self.label}' failed") from e else: self.__succeeded = True + finally: + if cleanup_references: + self.cleanup_references() + + def cleanup_references(self): + """Removes all references to the inputs. + Side effect: fixes the uhash of the task and outputs + """ + self.__inputs = None + self.__public_inputs = None + self.__outputs.cleanup_references() + super().cleanup_references() def run(self): """To be implemented by the derived classes""" diff --git a/ewokscore/tests/test_examples.py b/ewokscore/tests/test_examples.py index 2739e49e50b8e1bd54f91795c64a99864e01c7ec..6f0cc855aa19fd5795bd57cb6b4b4359a56e74ed 100644 --- a/ewokscore/tests/test_examples.py +++ b/ewokscore/tests/test_examples.py @@ -3,6 +3,7 @@ import itertools from .examples.graphs import graph_names from .examples.graphs import get_graph from .utils import assert_taskgraph_result +from .utils import assert_workflow_result from ewokscore import load_graph @@ -20,11 +21,17 @@ def test_execute_graph(graph_name, persist, tmpdir): with pytest.raises(RuntimeError): ewoksgraph.execute(varinfo=varinfo) else: - tasks = ewoksgraph.execute(varinfo=varinfo) + tasks = ewoksgraph.execute(varinfo=varinfo, results_of_all_nodes=True) assert_taskgraph_result(ewoksgraph, expected, tasks=tasks) if persist: assert_taskgraph_result(ewoksgraph, expected, varinfo=varinfo) + end_tasks = ewoksgraph.execute(varinfo=varinfo, results_of_all_nodes=False) + end_nodes = ewoksgraph.end_nodes() + assert end_tasks.keys() == end_nodes + expected = {k: v for k, v in expected.items() if k in end_nodes} + assert_workflow_result(end_tasks, expected, varinfo=varinfo) + def test_graph_cyclic(): g, _ = get_graph("empty") diff --git a/ewokscore/tests/test_hashing.py b/ewokscore/tests/test_hashing.py index a2700581dae6ab51a221ca96737e9d81e243afdb..c9ff93cb0b1c3fccf6d472ef8f909fbebae5da55 100644 --- a/ewokscore/tests/test_hashing.py +++ b/ewokscore/tests/test_hashing.py @@ -1,3 +1,4 @@ +import gc import numpy from ewokscore import hashing @@ -130,3 +131,30 @@ def test_uhash_fixing(): data[0] = 0 uhash = var.uhash assert uhash1org == uhash + + +def test_hashable_cleanup_references(): + class Myclass(hashing.UniversalHashable): + def __init__(self, data, **kw): + self.data = data + super().__init__(**kw) + + def _uhash_data(self): + return self.data + + obj1 = Myclass(10) + nref_start = len(gc.get_referrers(obj1)) + obj2 = Myclass(10, pre_uhash=obj1) + assert len(gc.get_referrers(obj1)) > nref_start + + obj1.data += 1 + assert obj1.uhash == obj2.uhash + + obj2.cleanup_references() + while gc.collect(): + pass + assert len(gc.get_referrers(obj1)) == nref_start + assert obj1.uhash == obj2.uhash + + obj1.data += 1 + assert obj1.uhash != obj2.uhash diff --git a/ewokscore/tests/test_sub_graph.py b/ewokscore/tests/test_sub_graph.py index 18bda7a1c03f0866a29006d807616c2d67e19b49..b3b1875c02059239607fd3df1a2fec30a1e12274 100644 --- a/ewokscore/tests/test_sub_graph.py +++ b/ewokscore/tests/test_sub_graph.py @@ -57,7 +57,7 @@ def test_sub_graph(): } ewoksgraph = load_graph(graph) - tasks = ewoksgraph.execute() + tasks = ewoksgraph.execute(results_of_all_nodes=True) expected = { "node1": {"return_value": 1}, ("node2", ("subnode1", "subsubnode1")): {"return_value": 2}, diff --git a/ewokscore/tests/test_sub_graph_json.py b/ewokscore/tests/test_sub_graph_json.py index bf4f80ab95ce08aafccab5a437778bb1011e685a..d1b71e51d85192a706c2fd813aa8fe940bc52527 100644 --- a/ewokscore/tests/test_sub_graph_json.py +++ b/ewokscore/tests/test_sub_graph_json.py @@ -222,7 +222,7 @@ def graph(tmpdir, subgraph): def test_load_from_json(tmpdir, graph): taskgraph = load_graph(graph, root_dir=str(tmpdir)) - tasks = taskgraph.execute() + tasks = taskgraph.execute(results_of_all_nodes=True) assert len(tasks) == 13 diff --git a/ewokscore/tests/test_task.py b/ewokscore/tests/test_task.py index 586338d09b4eaefb5c9c9aa272dec7ca638bad9d..9db790c00d27b5f5b1e4fab69b5cbdc51d0e09a8 100644 --- a/ewokscore/tests/test_task.py +++ b/ewokscore/tests/test_task.py @@ -1,3 +1,4 @@ +import gc import pytest import json from glob import glob @@ -143,3 +144,33 @@ def test_task_required_positional_inputs(): with pytest.raises(TaskInputError): MyTask() + + +def test_task_cleanup_references(): + class MyTask(Task, input_names=["mylist"], output_names=["mylist"]): + def run(self): + self.outputs.mylist = self.inputs.mylist + [len(self.inputs.mylist)] + + obj = [0, 1, 2] + nref_start = len(gc.get_referrers(obj)) + task1 = MyTask(inputs={"mylist": obj}) + task2 = MyTask(inputs=task1.output_variables) + task1.execute() + task2.execute() + assert len(gc.get_referrers(obj)) > nref_start + + uhash1 = task1.uhash + uhashes1 = task1.output_uhashes + uhash2 = task2.uhash + uhashes2 = task2.output_uhashes + + task1.cleanup_references() + + while gc.collect(): + pass + assert len(gc.get_referrers(obj)) == nref_start + + assert uhash1 == task1.uhash + assert uhashes1 == task1.output_uhashes + assert uhash2 == task2.uhash + assert uhashes2 == task2.output_uhashes diff --git a/ewokscore/tests/test_variable.py b/ewokscore/tests/test_variable.py index 7f9eddb02828b6414a6e877eaa5af69c871f6eff..e6d2e5f1222d39aaa137fb856a1d1cbed0be7347 100644 --- a/ewokscore/tests/test_variable.py +++ b/ewokscore/tests/test_variable.py @@ -1,3 +1,4 @@ +import gc import itertools import pytest from contextlib import contextmanager @@ -293,4 +294,45 @@ def test_variable_container_metadata(scheme, root_uri_type, tmpdir): assert ref_uri["var1"].metadata["@NX_class"] == "NXcollection" assert ref_uri["var1"].metadata["myvalue"] == 888 - print(container["var1"].data_uri) + +def test_variable_cleanup_references(): + obj = [0, 1, 2] + nref_start = len(gc.get_referrers(obj)) + var1 = Variable(obj) + var2 = Variable(pre_uhash=var1) + uhash = var1.uhash + assert uhash == var2.uhash + assert len(gc.get_referrers(obj)) > nref_start + + del var1 + while gc.collect(): + pass + assert len(gc.get_referrers(obj)) > nref_start + + var2.cleanup_references() + while gc.collect(): + pass + assert len(gc.get_referrers(obj)) == nref_start + + assert uhash == var2.uhash + + +def test_variable_container_cleanup_references(): + obj = [0, 1, 2] + nref_start = len(gc.get_referrers(obj)) + var1 = MutableVariableContainer({"myvar": obj}) + var2 = MutableVariableContainer(pre_uhash=var1) + uhash = var1.uhash + assert uhash == var2.uhash + + del var1 + while gc.collect(): + pass + assert len(gc.get_referrers(obj)) > nref_start + + var2.cleanup_references() + while gc.collect(): + pass + assert len(gc.get_referrers(obj)) == nref_start + + assert uhash == var2.uhash diff --git a/ewokscore/tests/utils.py b/ewokscore/tests/utils.py index 176a209294a230c98a230a2074435510a6e8142f..7b156227a5cc27897c9602219d956150dfcd443e 100644 --- a/ewokscore/tests/utils.py +++ b/ewokscore/tests/utils.py @@ -1,11 +1,20 @@ +from typing import Any, Dict, Optional import networkx from pprint import pprint import matplotlib.pyplot as plt from ewokscore import load_graph +from ewokscore.graph import TaskGraph +from ewokscore.node import NodeIdType +from ewokscore.task import Task from ewokscore.variable import value_from_transfer -def assert_taskgraph_result(taskgraph, expected, varinfo=None, tasks=None): +def assert_taskgraph_result( + taskgraph: TaskGraph, + expected: Dict[NodeIdType, Any], + varinfo: Optional[dict] = None, + tasks: Optional[Dict[NodeIdType, Task]] = None, +): taskgraph = load_graph(taskgraph) assert not taskgraph.is_cyclic, "Can only check DAG results" @@ -20,38 +29,47 @@ def assert_taskgraph_result(taskgraph, expected, varinfo=None, tasks=None): assert_task_result(task, node, expected) -def assert_task_result(task, node, expected): - expected_value = expected.get(node) +def assert_task_result(task: Task, node_id: NodeIdType, expected: dict): + expected_value = expected.get(node_id) if expected_value is None: - assert not task.done, node + assert not task.done, node_id else: - assert task.done, node + assert task.done, node_id try: - assert task.output_values == expected_value, node + assert task.output_values == expected_value, node_id except AssertionError: raise except Exception as e: - raise RuntimeError(f"{node} does not have a result") from e + raise RuntimeError(f"{node_id} does not have a result") from e -def assert_workflow_result(results, expected, varinfo=None): +def assert_workflow_result( + results: Dict[NodeIdType, Any], + expected: Dict[NodeIdType, Any], + varinfo: Optional[dict] = None, +): for node_id, expected_result in expected.items(): if expected_result is None: assert node_id not in results continue result = results[node_id] + if isinstance(result, Task): + assert result.done, node_id + result = result.output_values for output_name, expected_value in expected_result.items(): value = result[output_name] assert_result(value, expected_value, varinfo=varinfo) -def assert_workflow_merged_result(result, expected, varinfo=None): +def assert_workflow_merged_result( + result: dict, expected: Dict[NodeIdType, Any], varinfo: Optional[dict] = None +): for output_name, expected_value in expected.items(): value = result[output_name] assert_result(value, expected_value, varinfo=varinfo) -def assert_result(value, expected_value, varinfo=None): +def assert_result(value, expected_value, varinfo: Optional[dict] = None): value = value_from_transfer(value, varinfo=varinfo) assert value == expected_value diff --git a/ewokscore/variable.py b/ewokscore/variable.py index fae084a2a1f1fa831a487b489706d30a3c5f2ab8..12a30fa0fbcc789452d48c62025d96ed5698aa1f 100644 --- a/ewokscore/variable.py +++ b/ewokscore/variable.py @@ -75,8 +75,8 @@ class Variable(hashing.UniversalHashable): super().__init__(pre_uhash=pre_uhash, instance_nonce=instance_nonce) self.value = value - def fixed_uhash_copy(self): - """The uhash of the copy is fixed thereby remove references to other uhashable objects.""" + def copy_without_references(self): + """Copy that does not contain references to uhashable objects""" kwargs = self.get_uhash_init(serialize=True) kwargs["data_proxy"] = self.data_proxy return type(self)(value=self.value, **kwargs) @@ -213,10 +213,26 @@ class VariableContainer(Variable, Mapping): if value: self._update(value) - def fixed_uhash_copy(self): + def fix_uhash(self): + for var in self.values(): + var.fix_uhash() + return super().fix_uhash() + + def cleanup_references(self): + """Remove all references to other hashables. + Side effect: fixes the uhash when it depends on another hashable. + """ + for var in self.values(): + var.cleanup_references() + pre_uhash = self.__varparams.get("pre_uhash") + if isinstance(pre_uhash, hashing.HasUhash): + self.__varparams["pre_uhash"] = pre_uhash.uhash + return super().cleanup_references() + + def copy_without_references(self): """The uhash of the copy is fixed thereby remove references to other uhashable objects.""" return type(self)( - value={name: var.fixed_uhash_copy() for name, var in self.items()}, + value={name: var.copy_without_references() for name, var in self.items()}, **self.__varparams, ) @@ -413,7 +429,7 @@ class VariableContainer(Variable, Mapping): data[name] = var.data_proxy.uri elif var.hashing_enabled: # Remove possible references to a uhashable - data[name] = var.fixed_uhash_copy() + data[name] = var.copy_without_references() else: data[name] = var.value return data