Commit 4c96cf04 authored by payno
Browse files

Merge branch '17-ewoks-task-keeps-dataset-in-memory' into 'main'

Resolve "ewoks task keeps dataset in memory"

Closes #17

See merge request !72
parents 2d444508 c82b1fa4
Pipeline #60840 passed with stages
in 2 minutes and 9 seconds
from .task import Task # noqa: F401
from .taskwithprogress import TaskWithProgress # noqa: F401
from .graph import load_graph # noqa: F401
from .graph import execute_graph # noqa: F401
from .bindings import execute_graph # noqa: F401
# Package version string (pre-release).
__version__ = "0.0.5-alpha"
from typing import Optional
from .graph import load_graph
def execute_graph(graph, load_options: Optional[dict] = None, **execute_options):
    """Load a graph in any supported representation and run it sequentially.

    :param graph: graph to be executed
    :param Optional[dict] load_options: options to provide to the `load_graph` function (and as a consequence to the `TaskGraph.load` as `root_dir`)
    :param execute_options: options to provide to the Graph.execute function as `varinfo` or `raise_on_error`
    """
    taskgraph = load_graph(source=graph, **(load_options or dict()))
    return taskgraph.execute(**execute_options)
import os
import enum
import json
from collections import Counter
from collections.abc import Mapping
from typing import Optional, Set
from typing import Dict, Iterable, Optional, Set
import networkx
from . import inittask
......@@ -24,18 +25,6 @@ def load_graph(source=None, representation=None, **load_options):
return TaskGraph(source=source, representation=representation, **load_options)
def execute_graph(graph, load_options: Optional[dict] = None, **execute_options):
    """Load a graph and execute it sequentially.

    :param graph: graph to be executed
    :param Optional[dict] load_options: options to provide to the `load_graph` function (and as a consequence to the `TaskGraph.load` as `root_dir`)
    :param execute_options: options to provide to the Graph.execute function as `varinfo` or `raise_on_error`
    """
    if load_options is None:
        load_options = dict()
    # Normalize whatever representation was passed into a TaskGraph instance.
    graph = load_graph(source=graph, **load_options)
    return graph.execute(**execute_options)
def set_graph_defaults(graph_as_dict):
graph_as_dict.setdefault("directed", True)
graph_as_dict.setdefault("nodes", list())
......@@ -298,8 +287,9 @@ class TaskGraph:
def instantiate_task_static(
self,
node_id: NodeIdType,
tasks: Optional[dict] = None,
tasks: Optional[Dict[Task, int]] = None,
varinfo: Optional[dict] = None,
evict_result_counter: Optional[Dict[NodeIdType, int]] = None,
) -> Task:
"""Instantiate destination task while no access to the dynamic inputs.
Side effect: `tasks` will contain all predecessors.
......@@ -308,20 +298,34 @@ class TaskGraph:
raise RuntimeError(f"{self} is cyclic")
if tasks is None:
tasks = dict()
if evict_result_counter is None:
evict_result_counter = dict()
# Input from previous tasks (instantiate them if needed)
dynamic_inputs = dict()
for inputnode in self.predecessors(node_id):
inputtask = tasks.get(inputnode, None)
if inputtask is None:
inputtask = self.instantiate_task_static(
inputnode, tasks=tasks, varinfo=varinfo
for source_node_id in self.predecessors(node_id):
source_task = tasks.get(source_node_id, None)
if source_task is None:
source_task = self.instantiate_task_static(
source_node_id,
tasks=tasks,
varinfo=varinfo,
evict_result_counter=evict_result_counter,
)
link_attrs = self.graph[inputnode][node_id]
link_attrs = self.graph[source_node_id][node_id]
inittask.add_dynamic_inputs(
dynamic_inputs, link_attrs, inputtask.output_variables
dynamic_inputs, link_attrs, source_task.output_variables
)
task = self.instantiate_task(node_id, varinfo=varinfo, inputs=dynamic_inputs)
tasks[node_id] = task
return task
# Evict intermediate results
if evict_result_counter:
evict_result_counter[source_node_id] -= 1
if evict_result_counter[source_node_id] == 0:
tasks.pop(source_node_id)
# Instantiate the requested task
target_task = self.instantiate_task(
node_id, varinfo=varinfo, inputs=dynamic_inputs
)
tasks[node_id] = target_task
return target_task
def successors(self, node_id: NodeIdType, **include_filter):
yield from self._iter_downstream_nodes(
......@@ -625,20 +629,43 @@ class TaskGraph:
if linkattrs.get("on_error") and linkattrs.get("conditions"):
raise ValueError(err_msg.format("on_error", "conditions"))
def topological_sort(self):
def topological_sort(self) -> Iterable[NodeIdType]:
    """Sort node names for sequential instantiation+execution of DAGs.

    :raises RuntimeError: when the graph is cyclic (no topological order exists).
    """
    if self.is_cyclic:
        raise RuntimeError("Sorting nodes is not possible for cyclic graphs")
    yield from networkx.topological_sort(self.graph)
def execute(self, varinfo: Optional[dict] = None, raise_on_error: bool = True):
def successor_counter(self) -> Dict[NodeIdType, int]:
    """Count the outgoing edges of every node.

    Nodes without successors do not appear in the returned counter.
    """
    return Counter(edge[0] for edge in self.graph.edges)
def execute(
self,
varinfo: Optional[dict] = None,
raise_on_error: Optional[bool] = True,
results_of_all_nodes: Optional[bool] = False,
) -> Dict[NodeIdType, Task]:
"""Sequential execution of DAGs"""
if self.is_cyclic:
raise RuntimeError("Cannot execute cyclic graphs")
if self.has_conditional_links:
raise RuntimeError("Cannot execute graphs with conditional links")
if results_of_all_nodes:
evict_result_counter = None
else:
evict_result_counter = self.successor_counter()
cleanup_references = not results_of_all_nodes
tasks = dict()
for node in self.topological_sort():
task = self.instantiate_task_static(node, tasks=tasks, varinfo=varinfo)
task.execute(raise_on_error=raise_on_error)
for node_id in self.topological_sort():
task = self.instantiate_task_static(
node_id,
tasks=tasks,
varinfo=varinfo,
evict_result_counter=evict_result_counter,
)
task.execute(
raise_on_error=raise_on_error, cleanup_references=cleanup_references
)
return tasks
......@@ -237,6 +237,7 @@ class UniversalHashable(HasUhash):
return self.__instance_nonce
def fix_uhash(self):
"""Fix the uhash when it is derived from the uhash data."""
if self.__pre_uhash is not None:
return
keep, self.__instance_nonce = self.__instance_nonce, None
......@@ -249,6 +250,15 @@ class UniversalHashable(HasUhash):
def undo_fix_uhash(self):
    """Revert `fix_uhash`: restore the originally provided pre-uhash."""
    self.__pre_uhash = self.__original_pre_uhash
def cleanup_references(self):
    """Remove all references to other hashables.

    Side effect: fixes the uhash when it depends on another hashable
    (the referenced object's current uhash value is captured and kept).
    """
    if isinstance(self.__pre_uhash, HasUhash):
        # Replace the live object reference by its uhash value so the
        # referenced hashable can be garbage collected.
        pre_uhash = self.__pre_uhash.uhash
        self.__pre_uhash = pre_uhash
        self.__original_pre_uhash = pre_uhash
@property
def uhash(self) -> Optional[UniversalHash]:
_uhash = self.__pre_uhash
......
from collections.abc import Mapping
from typing import Union
from typing import Optional, Union
from .hashing import UniversalHashable
from .variable import VariableContainer
......@@ -150,6 +150,8 @@ class Task(Registered, UniversalHashable, register=False):
@property
def input_variables(self):
    # `cleanup_references` sets `__inputs` to None; accessing the inputs
    # after the references were removed is a programming error.
    if self.__inputs is None:
        raise RuntimeError("references have been removed")
    return self.__inputs
@property
......@@ -158,23 +160,23 @@ class Task(Registered, UniversalHashable, register=False):
@property
def input_uhashes(self):
return self.__inputs.variable_uhashes
return self.input_variables.variable_uhashes
@property
def input_values(self):
return self.__inputs.variable_values
return self.input_variables.variable_values
@property
def named_input_values(self):
return self.__inputs.named_variable_values
return self.input_variables.named_variable_values
@property
def positional_input_values(self):
return self.__inputs.positional_variable_values
return self.input_variables.positional_variable_values
@property
def npositional_inputs(self):
return self.__inputs.n_positional_variables
return self.input_variables.n_positional_variables
@property
def output_variables(self):
......@@ -246,7 +248,7 @@ class Task(Registered, UniversalHashable, register=False):
def _iter_missing_input_values(self):
for iname in self._INPUT_NAMES:
var = self.__inputs.get(iname)
var = self.input_variables.get(iname)
if var is None or not var.has_value:
yield iname
......@@ -265,7 +267,12 @@ class Task(Registered, UniversalHashable, register=False):
"The following inputs could not be loaded: " + str(lst)
)
def execute(self, force_rerun=False, raise_on_error=True):
def execute(
self,
force_rerun: Optional[bool] = False,
raise_on_error: Optional[bool] = True,
cleanup_references: Optional[bool] = False,
):
try:
if force_rerun:
# Rerun a task which is already done
......@@ -282,6 +289,18 @@ class Task(Registered, UniversalHashable, register=False):
raise RuntimeError(f"Task '{self.label}' failed") from e
else:
self.__succeeded = True
finally:
if cleanup_references:
self.cleanup_references()
def cleanup_references(self):
    """Removes all references to the inputs.

    Side effect: fixes the uhash of the task and outputs
    """
    # Drop the input containers so upstream tasks/variables can be
    # garbage collected once no other reference remains.
    self.__inputs = None
    self.__public_inputs = None
    # Freeze the output uhashes (presumably they depend on the inputs' uhashes).
    self.__outputs.cleanup_references()
    super().cleanup_references()
def run(self):
"""To be implemented by the derived classes"""
......
......@@ -3,6 +3,7 @@ import itertools
from .examples.graphs import graph_names
from .examples.graphs import get_graph
from .utils import assert_taskgraph_result
from .utils import assert_workflow_result
from ewokscore import load_graph
......@@ -20,11 +21,17 @@ def test_execute_graph(graph_name, persist, tmpdir):
with pytest.raises(RuntimeError):
ewoksgraph.execute(varinfo=varinfo)
else:
tasks = ewoksgraph.execute(varinfo=varinfo)
tasks = ewoksgraph.execute(varinfo=varinfo, results_of_all_nodes=True)
assert_taskgraph_result(ewoksgraph, expected, tasks=tasks)
if persist:
assert_taskgraph_result(ewoksgraph, expected, varinfo=varinfo)
end_tasks = ewoksgraph.execute(varinfo=varinfo, results_of_all_nodes=False)
end_nodes = ewoksgraph.end_nodes()
assert end_tasks.keys() == end_nodes
expected = {k: v for k, v in expected.items() if k in end_nodes}
assert_workflow_result(end_tasks, expected, varinfo=varinfo)
def test_graph_cyclic():
g, _ = get_graph("empty")
......
import gc
import numpy
from ewokscore import hashing
......@@ -130,3 +131,30 @@ def test_uhash_fixing():
data[0] = 0
uhash = var.uhash
assert uhash1org == uhash
def test_hashable_cleanup_references():
    """`cleanup_references` drops the reference to the source hashable
    while keeping the dependent uhash stable."""

    class Myclass(hashing.UniversalHashable):
        def __init__(self, data, **kw):
            self.data = data
            super().__init__(**kw)

        def _uhash_data(self):
            return self.data

    source = Myclass(10)
    baseline_refs = len(gc.get_referrers(source))
    dependent = Myclass(10, pre_uhash=source)
    # `dependent` keeps a live reference to `source`.
    assert len(gc.get_referrers(source)) > baseline_refs
    source.data += 1
    # The dependent uhash tracks the source while the reference is alive.
    assert source.uhash == dependent.uhash
    dependent.cleanup_references()
    while gc.collect():
        pass
    assert len(gc.get_referrers(source)) == baseline_refs
    assert source.uhash == dependent.uhash
    source.data += 1
    # After cleanup the dependent uhash is frozen and no longer tracks `source`.
    assert source.uhash != dependent.uhash
......@@ -57,7 +57,7 @@ def test_sub_graph():
}
ewoksgraph = load_graph(graph)
tasks = ewoksgraph.execute()
tasks = ewoksgraph.execute(results_of_all_nodes=True)
expected = {
"node1": {"return_value": 1},
("node2", ("subnode1", "subsubnode1")): {"return_value": 2},
......
......@@ -222,7 +222,7 @@ def graph(tmpdir, subgraph):
def test_load_from_json(tmpdir, graph):
taskgraph = load_graph(graph, root_dir=str(tmpdir))
tasks = taskgraph.execute()
tasks = taskgraph.execute(results_of_all_nodes=True)
assert len(tasks) == 13
......
import gc
import pytest
import json
from glob import glob
......@@ -143,3 +144,33 @@ def test_task_required_positional_inputs():
with pytest.raises(TaskInputError):
MyTask()
def test_task_cleanup_references():
    """`cleanup_references` releases the task inputs while keeping all
    task and output uhashes reproducible."""

    class MyTask(Task, input_names=["mylist"], output_names=["mylist"]):
        def run(self):
            self.outputs.mylist = self.inputs.mylist + [len(self.inputs.mylist)]

    payload = [0, 1, 2]
    baseline_refs = len(gc.get_referrers(payload))
    first = MyTask(inputs={"mylist": payload})
    second = MyTask(inputs=first.output_variables)
    first.execute()
    second.execute()
    # The task chain holds references to the payload.
    assert len(gc.get_referrers(payload)) > baseline_refs
    uhash_first = first.uhash
    output_uhashes_first = first.output_uhashes
    uhash_second = second.uhash
    output_uhashes_second = second.output_uhashes
    first.cleanup_references()
    while gc.collect():
        pass
    assert len(gc.get_referrers(payload)) == baseline_refs
    # Cleanup must not change any uhash.
    assert uhash_first == first.uhash
    assert output_uhashes_first == first.output_uhashes
    assert uhash_second == second.uhash
    assert output_uhashes_second == second.output_uhashes
import gc
import itertools
import pytest
from contextlib import contextmanager
......@@ -293,4 +294,45 @@ def test_variable_container_metadata(scheme, root_uri_type, tmpdir):
assert ref_uri["var1"].metadata["@NX_class"] == "NXcollection"
assert ref_uri["var1"].metadata["myvalue"] == 888
print(container["var1"].data_uri)
def test_variable_cleanup_references():
    """A variable created with another variable as `pre_uhash` drops that
    reference on cleanup without changing its uhash."""
    payload = [0, 1, 2]
    baseline_refs = len(gc.get_referrers(payload))
    original = Variable(payload)
    derived = Variable(pre_uhash=original)
    uhash = original.uhash
    assert uhash == derived.uhash
    assert len(gc.get_referrers(payload)) > baseline_refs
    del original
    while gc.collect():
        pass
    # `derived` still holds the source variable, which holds the payload.
    assert len(gc.get_referrers(payload)) > baseline_refs
    derived.cleanup_references()
    while gc.collect():
        pass
    assert len(gc.get_referrers(payload)) == baseline_refs
    assert uhash == derived.uhash
def test_variable_container_cleanup_references():
    # Payload held by the first container; referrer count tracks its lifetime.
    obj = [0, 1, 2]
    nref_start = len(gc.get_referrers(obj))
    var1 = MutableVariableContainer({"myvar": obj})
    var2 = MutableVariableContainer(pre_uhash=var1)
    uhash = var1.uhash
    assert uhash == var2.uhash
    del var1
    while gc.collect():
        pass
    # var2 still holds var1 through `pre_uhash`, which holds the payload.
    assert len(gc.get_referrers(obj)) > nref_start
    var2.cleanup_references()
    while gc.collect():
        pass
    # After cleanup only the original referrers of the payload remain ...
    assert len(gc.get_referrers(obj)) == nref_start
    # ... and the uhash value is preserved.
    assert uhash == var2.uhash
from typing import Any, Dict, Optional
import networkx
from pprint import pprint
import matplotlib.pyplot as plt
from ewokscore import load_graph
from ewokscore.graph import TaskGraph
from ewokscore.node import NodeIdType
from ewokscore.task import Task
from ewokscore.variable import value_from_transfer
def assert_taskgraph_result(taskgraph, expected, varinfo=None, tasks=None):
def assert_taskgraph_result(
taskgraph: TaskGraph,
expected: Dict[NodeIdType, Any],
varinfo: Optional[dict] = None,
tasks: Optional[Dict[NodeIdType, Task]] = None,
):
taskgraph = load_graph(taskgraph)
assert not taskgraph.is_cyclic, "Can only check DAG results"
......@@ -20,38 +29,47 @@ def assert_taskgraph_result(taskgraph, expected, varinfo=None, tasks=None):
assert_task_result(task, node, expected)
def assert_task_result(task, node, expected):
expected_value = expected.get(node)
def assert_task_result(task: Task, node_id: NodeIdType, expected: dict):
expected_value = expected.get(node_id)
if expected_value is None:
assert not task.done, node
assert not task.done, node_id
else:
assert task.done, node
assert task.done, node_id
try:
assert task.output_values == expected_value, node
assert task.output_values == expected_value, node_id
except AssertionError:
raise
except Exception as e:
raise RuntimeError(f"{node} does not have a result") from e
raise RuntimeError(f"{node_id} does not have a result") from e
def assert_workflow_result(results, expected, varinfo=None):
def assert_workflow_result(
    results: Dict[NodeIdType, Any],
    expected: Dict[NodeIdType, Any],
    varinfo: Optional[dict] = None,
):
    """Verify `results` against `expected`, node by node.

    A `None` expectation means the node must be absent from `results`.
    `Task` instances are unwrapped to their output values before comparing.
    """
    for node_id, node_expectation in expected.items():
        if node_expectation is None:
            assert node_id not in results
        else:
            actual = results[node_id]
            if isinstance(actual, Task):
                assert actual.done, node_id
                actual = actual.output_values
            for output_name, expected_value in node_expectation.items():
                assert_result(actual[output_name], expected_value, varinfo=varinfo)
def assert_workflow_merged_result(result, expected, varinfo=None):
def assert_workflow_merged_result(
    result: dict, expected: Dict[NodeIdType, Any], varinfo: Optional[dict] = None
):
    """Verify a merged (single mapping) workflow result against expected values."""
    for output_name, expected_value in expected.items():
        assert_result(result[output_name], expected_value, varinfo=varinfo)
def assert_result(value, expected_value, varinfo=None):
def assert_result(value, expected_value, varinfo: Optional[dict] = None):
    # Resolve `value` in case it is a transferable representation
    # (see `value_from_transfer`) before comparing.
    value = value_from_transfer(value, varinfo=varinfo)
    assert value == expected_value
......
......@@ -75,8 +75,8 @@ class Variable(hashing.UniversalHashable):
super().__init__(pre_uhash=pre_uhash, instance_nonce=instance_nonce)
self.value = value
def fixed_uhash_copy(self):
"""The uhash of the copy is fixed thereby remove references to other uhashable objects."""
def copy_without_references(self):
    """Copy that does not contain references to uhashable objects"""
    # Serialized uhash-init parameters replace live object references.
    kwargs = self.get_uhash_init(serialize=True)
    kwargs["data_proxy"] = self.data_proxy
    return type(self)(value=self.value, **kwargs)
......@@ -213,10 +213,26 @@ class VariableContainer(Variable, Mapping):
if value:
self._update(value)
def fixed_uhash_copy(self):
def fix_uhash(self):
    # Fix each contained variable first, then the container itself.
    for var in self.values():
        var.fix_uhash()
    return super().fix_uhash()
def cleanup_references(self):
    """Remove all references to other hashables.

    Side effect: fixes the uhash when it depends on another hashable.
    """
    for var in self.values():
        var.cleanup_references()
    # Replace a live `pre_uhash` reference by its uhash value so the
    # referenced hashable can be garbage collected.
    pre_uhash = self.__varparams.get("pre_uhash")
    if isinstance(pre_uhash, hashing.HasUhash):
        self.__varparams["pre_uhash"] = pre_uhash.uhash
    return super().cleanup_references()
def copy_without_references(self):
"""The uhash of the copy is fixed thereby remove references to other uhashable objects."""
return type(self)(
value={name: var.fixed_uhash_copy() for name, var in self.items()},
value={name: var.copy_without_references() for name, var in self.items()},
**self.__varparams,
)
......@@ -413,7 +429,7 @@ class VariableContainer(Variable, Mapping):
data[name] = var.data_proxy.uri
elif var.hashing_enabled:
# Remove possible references to a uhashable
data[name] = var.fixed_uhash_copy()
data[name] = var.copy_without_references()
else:
data[name] = var.value
return data
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment