Commit d9903145 authored by Wout De Nolf's avatar Wout De Nolf
Browse files

introduce a shared state for all tasks of a graph

parent 7cbd01a3
Pipeline #51062 passed with stages
in 43 seconds
...@@ -226,31 +226,40 @@ class TaskGraph: ...@@ -226,31 +226,40 @@ class TaskGraph:
return True return True
return False return False
def instantiate_task(self, node_name, varinfo=None, inputs=None): def instantiate_task(self, node_name, inputs=None, varinfo=None, shared_state=None):
"""Named arguments are dynamic input and Variable config. """Named arguments are dynamic input and Variable config.
Static input from the persistent representation are Static input from the persistent representation are
added internally. added internally.
:param str node_name: :param str node_name:
:param dict or None tasks: keeps upstream tasks :param dict or None inputs: optional dynamic inputs
:param **inputs: dynamic inputs :param dict or None varinfo:
:param dict or None shared_state:
:returns Task: :returns Task:
""" """
# Dynamic input has priority over static input # Dynamic input has priority over static input
nodeattrs = self.graph.nodes[node_name] nodeattrs = self.graph.nodes[node_name]
return inittask.instantiate_task( return inittask.instantiate_task(
nodeattrs, node_name=node_name, varinfo=varinfo, inputs=inputs nodeattrs,
node_name=node_name,
inputs=inputs,
varinfo=varinfo,
shared_state=shared_state,
) )
def instantiate_task_static(self, node_name, tasks=None, varinfo=None, inputs=None): def instantiate_task_static(
self, node_name, inputs=None, varinfo=None, shared_state=None, tasks=None
):
"""Instantiate destination task while no or partial access to the dynamic """Instantiate destination task while no or partial access to the dynamic
inputs or their identifiers. Side effect: `tasks` will contain all predecessors. inputs or their identifiers. Side effect: `tasks` will contain all predecessors.
Remark: Only works for DAGs. Remark: Only works for DAGs.
:param str node_name: :param str node_name:
:param dict or None tasks: keeps upstream tasks
:param dict or None inputs: optional dynamic inputs :param dict or None inputs: optional dynamic inputs
:param dict or None varinfo:
:param dict or None shared_state:
:param dict or None tasks: keeps upstream tasks
:returns Task: :returns Task:
""" """
if self.is_cyclic: if self.is_cyclic:
...@@ -262,7 +271,7 @@ class TaskGraph: ...@@ -262,7 +271,7 @@ class TaskGraph:
inputtask = tasks.get(inputnode, None) inputtask = tasks.get(inputnode, None)
if inputtask is None: if inputtask is None:
inputtask = self.instantiate_task_static( inputtask = self.instantiate_task_static(
inputnode, tasks=tasks, varinfo=varinfo inputnode, varinfo=varinfo, shared_state=shared_state, tasks=tasks
) )
link_attrs = self.graph[inputnode][node_name] link_attrs = self.graph[inputnode][node_name]
all_arguments = link_attrs.get("all_arguments", False) all_arguments = link_attrs.get("all_arguments", False)
...@@ -284,7 +293,9 @@ class TaskGraph: ...@@ -284,7 +293,9 @@ class TaskGraph:
dynamic_inputs[to_arg] = inputtask.output_variables dynamic_inputs[to_arg] = inputtask.output_variables
if inputs: if inputs:
dynamic_inputs.update(inputs) dynamic_inputs.update(inputs)
task = self.instantiate_task(node_name, varinfo=varinfo, inputs=dynamic_inputs) task = self.instantiate_task(
node_name, inputs=dynamic_inputs, varinfo=varinfo, shared_state=shared_state
)
tasks[node_name] = task tasks[node_name] = task
return task return task
...@@ -599,7 +610,7 @@ class TaskGraph: ...@@ -599,7 +610,7 @@ class TaskGraph:
raise RuntimeError("Sorting nodes is not possible for cyclic graphs") raise RuntimeError("Sorting nodes is not possible for cyclic graphs")
yield from networkx.topological_sort(self.graph) yield from networkx.topological_sort(self.graph)
def execute(self, varinfo=None): def execute(self, varinfo=None, shared_state=None):
"""Sequential execution of DAGs""" """Sequential execution of DAGs"""
if self.is_cyclic: if self.is_cyclic:
raise RuntimeError("Cannot execute cyclic graphs") raise RuntimeError("Cannot execute cyclic graphs")
...@@ -607,6 +618,8 @@ class TaskGraph: ...@@ -607,6 +618,8 @@ class TaskGraph:
raise RuntimeError("Cannot execute graphs with conditional links") raise RuntimeError("Cannot execute graphs with conditional links")
tasks = dict() tasks = dict()
for node in self.topological_sort(): for node in self.topological_sort():
task = self.instantiate_task_static(node, tasks=tasks, varinfo=varinfo) task = self.instantiate_task_static(
node, tasks=tasks, varinfo=varinfo, shared_state=shared_state
)
task.execute() task.execute()
return tasks return tasks
...@@ -59,12 +59,15 @@ def validate_task_executable(node_attrs, node_name="", all=False): ...@@ -59,12 +59,15 @@ def validate_task_executable(node_attrs, node_name="", all=False):
task_executable_key(node_attrs, node_name=node_name, all=all) task_executable_key(node_attrs, node_name=node_name, all=all)
def instantiate_task(node_attrs, varinfo=None, inputs=None, node_name=""): def instantiate_task(
node_attrs, node_name="", inputs=None, varinfo=None, shared_state=None
):
""" """
:param dict node_attrs: node attributes of the graph representation :param dict node_attrs: node attributes of the graph representation
:param dict varinfo: `Variable` constructor arguments
:param dict or None inputs: dynamic inputs (from other tasks)
:param str node_name: :param str node_name:
:param dict or None inputs: dynamic inputs (from other tasks)
:param dict or None varinfo: `Variable` constructor arguments
:param dict or None shared_state:
:returns Task: :returns Task:
""" """
# Static inputs # Static inputs
...@@ -76,19 +79,29 @@ def instantiate_task(node_attrs, varinfo=None, inputs=None, node_name=""): ...@@ -76,19 +79,29 @@ def instantiate_task(node_attrs, varinfo=None, inputs=None, node_name=""):
# Instantiate task # Instantiate task
key, value = task_executable_key(node_attrs, node_name=node_name) key, value = task_executable_key(node_attrs, node_name=node_name)
if key == "class": if key == "class":
return Task.instantiate(value, inputs=task_inputs, varinfo=varinfo) return Task.instantiate(
value, inputs=task_inputs, varinfo=varinfo, shared_state=shared_state
)
elif key == "method": elif key == "method":
task_inputs["method"] = value task_inputs["method"] = value
return MethodExecutorTask(inputs=task_inputs, varinfo=varinfo) return MethodExecutorTask(
inputs=task_inputs, varinfo=varinfo, shared_state=shared_state
)
elif key == "ppfmethod": elif key == "ppfmethod":
task_inputs["method"] = value task_inputs["method"] = value
return PpfMethodExecutorTask(inputs=task_inputs, varinfo=varinfo) return PpfMethodExecutorTask(
inputs=task_inputs, varinfo=varinfo, shared_state=shared_state
)
elif key == "ppfport": elif key == "ppfport":
task_inputs["ppfport"] = value task_inputs["ppfport"] = value
return PpfPortTask(inputs=task_inputs, varinfo=varinfo) return PpfPortTask(
inputs=task_inputs, varinfo=varinfo, shared_state=shared_state
)
elif key == "script": elif key == "script":
task_inputs["script"] = value task_inputs["script"] = value
return ScriptExecutorTask(inputs=task_inputs, varinfo=varinfo) return ScriptExecutorTask(
inputs=task_inputs, varinfo=varinfo, shared_state=shared_state
)
else: else:
raise_task_error(node_name, all=False) raise_task_error(node_name, all=False)
......
...@@ -28,7 +28,7 @@ class Task(Registered, hashing.UniversalHashable, register=False): ...@@ -28,7 +28,7 @@ class Task(Registered, hashing.UniversalHashable, register=False):
_OUTPUT_NAMES = set() _OUTPUT_NAMES = set()
_N_REQUIRED_POSITIONAL_INPUTS = 0 _N_REQUIRED_POSITIONAL_INPUTS = 0
def __init__(self, inputs=None, varinfo=None): def __init__(self, inputs=None, varinfo=None, shared_state=None):
"""The named arguments are inputs and Variable configuration""" """The named arguments are inputs and Variable configuration"""
if inputs is None: if inputs is None:
inputs = dict() inputs = dict()
...@@ -59,6 +59,9 @@ class Task(Registered, hashing.UniversalHashable, register=False): ...@@ -59,6 +59,9 @@ class Task(Registered, hashing.UniversalHashable, register=False):
# Misc # Misc
self._exception = None self._exception = None
self._done = None self._done = None
if shared_state is None:
shared_state = dict()
self.shared_state = shared_state
# The output hash will update dynamically if any of the input # The output hash will update dynamically if any of the input
# variables change # variables change
......
...@@ -77,3 +77,22 @@ def test_wrong_argument_definitions(): ...@@ -77,3 +77,22 @@ def test_wrong_argument_definitions():
links[0]["required"] = True links[0]["required"] = True
with pytest.raises(ValueError): with pytest.raises(ValueError):
load_graph(graph) load_graph(graph)
def test_shared_state():
nodes = [
{"id": "a", "method": "builtins.float", "inputs": {0: 0}},
{"id": "b", "method": "builtins.float"},
{"id": "c", "method": "builtins.float"},
{"id": "d", "method": "builtins.float"},
]
links = [
{"source": "a", "target": "b", "arguments": {0: "return_value"}},
{"source": "b", "target": "c", "arguments": {0: "return_value"}},
{"source": "c", "target": "d", "arguments": {0: "return_value"}},
]
ewoksgraph = load_graph({"nodes": nodes, "links": links})
shared_state = dict()
tasks = ewoksgraph.execute(shared_state=shared_state)
for task in tasks.values():
assert task.shared_state is shared_state
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment