Commit d9903145 authored by Wout De Nolf's avatar Wout De Nolf
Browse files

introduce a shared state for all tasks of a graph

parent 7cbd01a3
Pipeline #51062 passed with stages
in 43 seconds
......@@ -226,31 +226,40 @@ class TaskGraph:
return True
return False
def instantiate_task(self, node_name, varinfo=None, inputs=None):
def instantiate_task(self, node_name, inputs=None, varinfo=None, shared_state=None):
"""Named arguments are dynamic input and Variable config.
Static input from the persistent representation are
added internally.
:param str node_name:
:param dict or None tasks: keeps upstream tasks
:param **inputs: dynamic inputs
:param dict or None inputs: optional dynamic inputs
:param dict or None varinfo:
:param dict or None shared_state:
:returns Task:
"""
# Dynamic input has priority over static input
nodeattrs = self.graph.nodes[node_name]
return inittask.instantiate_task(
nodeattrs, node_name=node_name, varinfo=varinfo, inputs=inputs
nodeattrs,
node_name=node_name,
inputs=inputs,
varinfo=varinfo,
shared_state=shared_state,
)
def instantiate_task_static(self, node_name, tasks=None, varinfo=None, inputs=None):
def instantiate_task_static(
self, node_name, inputs=None, varinfo=None, shared_state=None, tasks=None
):
"""Instantiate destination task while no or partial access to the dynamic
inputs or their identifiers. Side effect: `tasks` will contain all predecessors.
Remark: Only works for DAGs.
:param str node_name:
:param dict or None tasks: keeps upstream tasks
:param dict or None inputs: optional dynamic inputs
:param dict or None varinfo:
:param dict or None shared_state:
:param dict or None tasks: keeps upstream tasks
:returns Task:
"""
if self.is_cyclic:
......@@ -262,7 +271,7 @@ class TaskGraph:
inputtask = tasks.get(inputnode, None)
if inputtask is None:
inputtask = self.instantiate_task_static(
inputnode, tasks=tasks, varinfo=varinfo
inputnode, varinfo=varinfo, shared_state=shared_state, tasks=tasks
)
link_attrs = self.graph[inputnode][node_name]
all_arguments = link_attrs.get("all_arguments", False)
......@@ -284,7 +293,9 @@ class TaskGraph:
dynamic_inputs[to_arg] = inputtask.output_variables
if inputs:
dynamic_inputs.update(inputs)
task = self.instantiate_task(node_name, varinfo=varinfo, inputs=dynamic_inputs)
task = self.instantiate_task(
node_name, inputs=dynamic_inputs, varinfo=varinfo, shared_state=shared_state
)
tasks[node_name] = task
return task
......@@ -599,7 +610,7 @@ class TaskGraph:
raise RuntimeError("Sorting nodes is not possible for cyclic graphs")
yield from networkx.topological_sort(self.graph)
def execute(self, varinfo=None):
def execute(self, varinfo=None, shared_state=None):
"""Sequential execution of DAGs"""
if self.is_cyclic:
raise RuntimeError("Cannot execute cyclic graphs")
......@@ -607,6 +618,8 @@ class TaskGraph:
raise RuntimeError("Cannot execute graphs with conditional links")
tasks = dict()
for node in self.topological_sort():
task = self.instantiate_task_static(node, tasks=tasks, varinfo=varinfo)
task = self.instantiate_task_static(
node, tasks=tasks, varinfo=varinfo, shared_state=shared_state
)
task.execute()
return tasks
......@@ -59,12 +59,15 @@ def validate_task_executable(node_attrs, node_name="", all=False):
task_executable_key(node_attrs, node_name=node_name, all=all)
def instantiate_task(node_attrs, varinfo=None, inputs=None, node_name=""):
def instantiate_task(
node_attrs, node_name="", inputs=None, varinfo=None, shared_state=None
):
"""
:param dict node_attrs: node attributes of the graph representation
:param dict varinfo: `Variable` constructor arguments
:param dict or None inputs: dynamic inputs (from other tasks)
:param str node_name:
:param dict or None inputs: dynamic inputs (from other tasks)
:param dict or None varinfo: `Variable` constructor arguments
:param dict or None shared_state:
:returns Task:
"""
# Static inputs
......@@ -76,19 +79,29 @@ def instantiate_task(node_attrs, varinfo=None, inputs=None, node_name=""):
# Instantiate task
key, value = task_executable_key(node_attrs, node_name=node_name)
if key == "class":
return Task.instantiate(value, inputs=task_inputs, varinfo=varinfo)
return Task.instantiate(
value, inputs=task_inputs, varinfo=varinfo, shared_state=shared_state
)
elif key == "method":
task_inputs["method"] = value
return MethodExecutorTask(inputs=task_inputs, varinfo=varinfo)
return MethodExecutorTask(
inputs=task_inputs, varinfo=varinfo, shared_state=shared_state
)
elif key == "ppfmethod":
task_inputs["method"] = value
return PpfMethodExecutorTask(inputs=task_inputs, varinfo=varinfo)
return PpfMethodExecutorTask(
inputs=task_inputs, varinfo=varinfo, shared_state=shared_state
)
elif key == "ppfport":
task_inputs["ppfport"] = value
return PpfPortTask(inputs=task_inputs, varinfo=varinfo)
return PpfPortTask(
inputs=task_inputs, varinfo=varinfo, shared_state=shared_state
)
elif key == "script":
task_inputs["script"] = value
return ScriptExecutorTask(inputs=task_inputs, varinfo=varinfo)
return ScriptExecutorTask(
inputs=task_inputs, varinfo=varinfo, shared_state=shared_state
)
else:
raise_task_error(node_name, all=False)
......
......@@ -28,7 +28,7 @@ class Task(Registered, hashing.UniversalHashable, register=False):
_OUTPUT_NAMES = set()
_N_REQUIRED_POSITIONAL_INPUTS = 0
def __init__(self, inputs=None, varinfo=None):
def __init__(self, inputs=None, varinfo=None, shared_state=None):
"""The named arguments are inputs and Variable configuration"""
if inputs is None:
inputs = dict()
......@@ -59,6 +59,9 @@ class Task(Registered, hashing.UniversalHashable, register=False):
# Misc
self._exception = None
self._done = None
if shared_state is None:
shared_state = dict()
self.shared_state = shared_state
# The output hash will update dynamically if any of the input
# variables change
......
......@@ -77,3 +77,22 @@ def test_wrong_argument_definitions():
links[0]["required"] = True
with pytest.raises(ValueError):
load_graph(graph)
def test_shared_state():
nodes = [
{"id": "a", "method": "builtins.float", "inputs": {0: 0}},
{"id": "b", "method": "builtins.float"},
{"id": "c", "method": "builtins.float"},
{"id": "d", "method": "builtins.float"},
]
links = [
{"source": "a", "target": "b", "arguments": {0: "return_value"}},
{"source": "b", "target": "c", "arguments": {0: "return_value"}},
{"source": "c", "target": "d", "arguments": {0: "return_value"}},
]
ewoksgraph = load_graph({"nodes": nodes, "links": links})
shared_state = dict()
tasks = ewoksgraph.execute(shared_state=shared_state)
for task in tasks.values():
assert task.shared_state is shared_state
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment