Commit 032d8611 authored by Wout De Nolf's avatar Wout De Nolf
Browse files

refactor the conditional link else value

parent 548537d6
Pipeline #63621 passed with stages
in 48 seconds
......@@ -150,6 +150,7 @@ Node attributes
}
* *inputs_complete* (optional): set to `True` when the default input covers all required input
(used for method and script as the required inputs are unknown)
* *conditions_else_value* (optional): value used in conditional links to indicate the *else* value (`None` by default)
Link attributes
^^^^^^^^^^^^^^^
......
......@@ -2,7 +2,7 @@ import os
import enum
import json
import yaml
from collections import Counter
from collections import Counter, defaultdict
from collections.abc import Mapping
from typing import Any, Dict, Iterable, List, Optional, Set, Union
import networkx
......@@ -19,8 +19,6 @@ from .node import node_id_from_json
from .node import get_node_label
from . import hashing
CONDITIONS_ELSE_VALUE = "__other__"
def load_graph(source=None, representation=None, **load_options):
if isinstance(source, TaskGraph):
......@@ -562,6 +560,7 @@ class TaskGraph:
)
def end_nodes(self) -> Set[NodeIdType]:
"""Node that could potentially be the end of a graph execution thread"""
nodes = set(
node_id for node_id in self.graph.nodes if not self.has_successors(node_id)
)
......@@ -570,67 +569,34 @@ class TaskGraph:
return set(
node_id
for node_id in self.graph.nodes
if self._node_has_noncovered_conditions(node_id)
if self.node_has_noncovered_conditions(node_id)
)
def _node_has_noncovered_conditions(self, source_id: NodeIdType) -> bool:
links = self._get_node_expanded_conditions(source_id)
has_complement = [False] * len(links)
default_complements = {CONDITIONS_ELSE_VALUE}
def node_condition_values(self, source_id: NodeIdType) -> Dict[str, set]:
condition_values = defaultdict(set)
for target_id in self.successors(source_id, link_has_conditions=True):
for condition in self.graph[source_id][target_id]["conditions"]:
varname = condition["source_output"]
value = condition["value"]
condition_values[varname].add(value)
return condition_values
def node_has_noncovered_conditions(self, source_id: NodeIdType) -> bool:
conditions_else_value = self.graph.nodes[source_id].get(
"conditions_else_value", None
)
complements = {
CONDITIONS_ELSE_VALUE: None,
True: {False, CONDITIONS_ELSE_VALUE},
False: {True, CONDITIONS_ELSE_VALUE},
True: {False, conditions_else_value},
False: {True, conditions_else_value},
}
for i, conditions1 in enumerate(links):
if has_complement[i]:
continue
for j in range(i + 1, len(links)):
conditions2 = links[j]
if self._conditions_are_complementary(
conditions1, conditions2, default_complements, complements
):
has_complement[i] = True
has_complement[j] = True
break
if not has_complement[i]:
return True
condition_values = self.node_condition_values(source_id)
for values in condition_values.values():
for value in values:
cvalue = complements.get(value, conditions_else_value)
if cvalue not in values:
return True
return False
@staticmethod
def _conditions_are_complementary(
conditions1, conditions2, default_complements, complements
):
for varname, value in conditions1.items():
complementary_values = complements.get(value, default_complements)
if complementary_values is None:
# Any value is complementary
continue
if conditions2[varname] not in complementary_values:
return False
return True
def _get_node_expanded_conditions(self, source_id: NodeIdType):
"""Ensure that conditional link starting from a source node has
the same set of output names.
"""
links = [
self.graph[source_id][target_id]["conditions"]
for target_id in self.successors(source_id, link_has_conditions=True)
]
all_condition_names = {
item["source_output"] for conditions in links for item in conditions
}
for conditions in links:
link_condition_names = {item["source_output"] for item in conditions}
for name in all_condition_names - link_condition_names:
conditions.append(
{"source_output": name, "value": CONDITIONS_ELSE_VALUE}
)
return links
def validate_graph(self):
for node_id, node_attrs in self.graph.nodes.items():
inittask.validate_task_executable(node_attrs, node_id=node_id)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment