Skip to content
Snippets Groups Projects
Commit 3f31bbd5 authored by Loic Huder's avatar Loic Huder
Browse files

Replace `inputs_complete` by `start_node` to explicitly mark nodes as start nodes

parent 375fee02
No related branches found
No related tags found
No related merge requests found
Pipeline #193737 failed
This commit is part of merge request !236. Comments created here will be created in the context of that merge request.
...@@ -151,8 +151,8 @@ Node attributes ...@@ -151,8 +151,8 @@ Node attributes
{ {
"default_inputs": [{"name":"a", "value":1}] "default_inputs": [{"name":"a", "value":1}]
} }
* *inputs_complete* (optional): set to `True` when the default inputs cover all required inputs * *start_node* (optional): when set to `True`, the node will be explicitly defined as a start node i.e. a node that should be executed before all others.
(used for *method*, *script* and *notebook* as their required inputs are unknown) (To be used as an escape hatch when the graph analysis fails to correctly assert the start nodes).
* *conditions_else_value* (optional): value used in conditional links to indicate the *else* value (Default: `None`) * *conditions_else_value* (optional): value used in conditional links to indicate the *else* value (Default: `None`)
* *default_error_node* (optional): when set to `True` all nodes without error handler will be linked to this node. * *default_error_node* (optional): when set to `True` all nodes without error handler will be linked to this node.
* *default_error_attributes* (optional): when `default_error_node=True` this dictionary is used as attributes for the * *default_error_attributes* (optional): when `default_error_node=True` this dictionary is used as attributes for the
......
...@@ -257,12 +257,11 @@ def has_required_predecessors(graph: networkx.DiGraph, node_id: NodeIdType) -> b ...@@ -257,12 +257,11 @@ def has_required_predecessors(graph: networkx.DiGraph, node_id: NodeIdType) -> b
def has_required_static_inputs(graph: networkx.DiGraph, node_id: NodeIdType) -> bool: def has_required_static_inputs(graph: networkx.DiGraph, node_id: NodeIdType) -> bool:
"""Returns True when the default inputs cover all required inputs.""" """Returns True when the default inputs cover all required inputs."""
node_attrs = graph.nodes[node_id] node_attrs = graph.nodes[node_id]
inputs_complete = node_attrs.get("inputs_complete", None) if node_attrs.get("task_type", None) != "class":
if isinstance(inputs_complete, bool): # Tasks that are not `class` (e.g. `method` and `script`)
# method and script tasks always have an empty `required_input_names` # always have an empty `required_input_names`
# although they may have required input. This keyword is used the # although they may have required input.
# manually indicate that all required inputs are statically provided. return False
return inputs_complete
taskclass = get_task_class(node_id, node_attrs) taskclass = get_task_class(node_id, node_attrs)
static_inputs = {d["name"] for d in node_attrs.get("default_inputs", list())} static_inputs = {d["name"] for d in node_attrs.get("default_inputs", list())}
return not (set(taskclass.required_input_names()) - static_inputs) return not (set(taskclass.required_input_names()) - static_inputs)
...@@ -297,13 +296,23 @@ def node_has_noncovered_conditions( ...@@ -297,13 +296,23 @@ def node_has_noncovered_conditions(
return False return False
def node_is_start_node(graph: networkx.DiGraph, node_id: NodeIdType) -> bool:
node = graph.nodes[node_id]
if node.get("start_node", False):
return True
return not node_has_predecessors(graph, node_id)
def start_nodes(graph: networkx.DiGraph) -> Set[NodeIdType]: def start_nodes(graph: networkx.DiGraph) -> Set[NodeIdType]:
"""Nodes from which the graph execution starts""" """Nodes from which the graph execution starts"""
nodes = set(
node_id for node_id in graph.nodes if not node_has_predecessors(graph, node_id) start_nodes: Set[NodeIdType] = set(
node_id for node_id in graph.nodes if node_is_start_node(graph, node_id)
) )
if nodes: if start_nodes:
return nodes return start_nodes
return set( return set(
node_id node_id
for node_id in graph.nodes for node_id in graph.nodes
......
...@@ -112,7 +112,7 @@ def _get_subnode_attributes( ...@@ -112,7 +112,7 @@ def _get_subnode_attributes(
"""Update all input node attributes of the subgraph with the graph node attributes from the super graph""" """Update all input node attributes of the subgraph with the graph node attributes from the super graph"""
transfer_attributes = { transfer_attributes = {
"default_inputs", "default_inputs",
"inputs_complete", "start_node",
"conditions_else_value", "conditions_else_value",
"default_error_node", "default_error_node",
} }
......
import pytest
import os.path
from ewokscore import Task
from ewoks import execute_graph
def test_explicit_start_node_on_self_triggering_node(tmpdir):
"""
Workflow:
- LOOP depends on itself (self-triggering task).
- SAVE depends on LOOP and CONFIG
- CONFIG does not depend on anything
/|
/ |
LOOP -- SAVE
/
/
/
CONFIG
This test tests that, while graph analysis considers that the only
start node is CONFIG, it is possible to set `start_node` for LOOP
and make it a start node so that SAVE can get inputs from LOOP and
CONFIG.
"""
class ConfigTask(Task, input_names=["filename"], output_names=["config"]):
def run(self):
self.outputs.config = {"filename": self.inputs.filename}
class LoopTask(
Task,
input_names=["i", "n"],
output_names=["i", "keep_looping"],
):
def run(self):
self.outputs.i = self.inputs.i + 1
self.outputs.keep_looping = self.outputs.i < self.inputs.n
class SaveTask(
Task,
input_names=["config"],
optional_input_names=["i"],
output_names=["result"],
):
def run(self):
if self.missing_inputs.i:
raise RuntimeError("LOOP not executed!")
config = self.inputs.config
with open(config["filename"], "a") as out_file:
out_file.write(f"LOOP executed: {self.inputs.i}")
workflow = {
"graph": {"id": "testworkflow"},
"nodes": [
{
"id": "CONFIG",
"task_type": "class",
"task_identifier": "ewokscore.tests.test_start_node.ConfigTask",
},
{
"id": "LOOP",
"task_type": "class",
"task_identifier": "ewokscore.tests.test_start_node.LoopTask",
},
{
"id": "SAVE",
"task_type": "class",
"task_identifier": "ewokscore.tests.test_start_node.SaveTask",
},
],
"links": [
{
"source": "LOOP",
"target": "LOOP",
"data_mapping": [{"source_output": "i", "target_input": "i"}],
"conditions": [{"source_output": "keep_looping", "value": True}],
},
{
"source": "LOOP",
"target": "SAVE",
"data_mapping": [{"source_output": "i", "target_input": "i"}],
},
{
"source": "CONFIG",
"target": "SAVE",
"data_mapping": [{"source_output": "config", "target_input": "config"}],
},
],
}
filename = str(tmpdir / "test.txt")
max_iterations = 10
inputs = [
{
"task_identifier": "ewokscore.tests.test_start_node.ConfigTask",
"name": "filename",
"value": filename,
},
{
"task_identifier": "ewokscore.tests.test_start_node.LoopTask",
"name": "i",
"value": 0,
},
{
"task_identifier": "ewokscore.tests.test_start_node.LoopTask",
"name": "n",
"value": max_iterations,
},
]
# Execute without setting `start_node` in the LOOP node: the SAVE task fails.
with pytest.raises(RuntimeError) as e:
execute_graph(
workflow,
engine="ppf",
scaling_workers=False,
pool_type="thread",
inputs=inputs,
)
assert str(e) == "RuntimeError: Task 'SAVE' failed"
assert not os.path.exists(filename)
# Execute with setting `start_node` in the LOOP node; the SAVE task runs and the file exists.
loop_node = workflow["nodes"][1]
assert loop_node["id"] == "LOOP"
loop_node["start_node"] = True
execute_graph(
workflow,
engine="ppf",
scaling_workers=False,
pool_type="thread",
inputs=inputs,
)
with open(filename) as _file:
txt = _file.read()
for i in range(max_iterations):
assert str(i) in txt
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment