import os
import enum
import json
from collections.abc import Mapping
from typing import Optional, Set

import networkx

from . import inittask
from .utils import qualname
from .utils import dict_merge
from .subgraph import extract_graph_nodes
from .subgraph import add_subgraph_links
from .task import Task
from .node import NodeIdType
from .node import node_id_from_json

CONDITIONS_ELSE_VALUE = "__other__"


def load_graph(source=None, representation=None, **load_options):
    if isinstance(source, TaskGraph):
        return source
    else:
        return TaskGraph(source=source, representation=representation, **load_options)
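
# Example (minimal sketch; a single-node workflow in node-link format,
# "json.dumps" stands in for any importable callable):
#
#     workflow = {
#         "graph": {"id": "demo"},
#         "nodes": [{"id": "a", "task_type": "method", "task_identifier": "json.dumps"}],
#         "links": [],
#     }
#     taskgraph = load_graph(workflow)
#     assert load_graph(taskgraph) is taskgraph  # already-loaded graphs pass through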


def set_graph_defaults(graph_as_dict):
    graph_as_dict.setdefault("directed", True)
    graph_as_dict.setdefault("nodes", list())
    graph_as_dict.setdefault("links", list())


def node_has_links(graph, node_id):
    try:
        next(graph.successors(node_id))
    except StopIteration:
        try:
            next(graph.predecessors(node_id))
        except StopIteration:
            return False
    return True
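
# Example (sketch):
#
#     g = networkx.DiGraph([("a", "b")])
#     g.add_node("c")
#     node_has_links(g, "a")  # -> True: "a" has a successor
#     node_has_links(g, "c")  # -> False: "c" is isolated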


def merge_graphs(graphs, graph_attrs=None, rename_nodes=None, **load_options):
    lst = list()
    if rename_nodes is None:
        rename_nodes = [True] * len(graphs)
    else:
        assert len(graphs) == len(rename_nodes)
    for g, rename in zip(graphs, rename_nodes):
        g = load_graph(g, **load_options)
        gname = repr(g)
        g = g.graph
        if rename:
            mapping = {s: (gname, s) for s in g.nodes}
            g = networkx.relabel_nodes(g, mapping, copy=True)
        lst.append(g)
    ret = load_graph(networkx.compose_all(lst), **load_options)
    if graph_attrs:
        ret.graph.graph.update(graph_attrs)
    return ret
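
# Example (sketch; wf1 and wf2 are hypothetical workflow sources accepted by
# `load_graph`, each carrying a graph attribute {"id": ...}):
#
#     merged = merge_graphs([wf1, wf2], graph_attrs={"id": "merged"})
#     # With the default `rename_nodes`, node "a" of the graph labeled "wf1"
#     # becomes ("wf1", "a") in the merged graph.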


def flatten_multigraph(graph: networkx.DiGraph) -> networkx.DiGraph:
    """Convert a multigraph to a simple directed graph: the attributes of
    parallel links between the same two nodes are merged."""
    if not graph.is_multigraph():
        return graph
    newgraph = networkx.DiGraph(**graph.graph)

    edgeattrs = dict()
    for edge, attrs in graph.edges.items():
        key = edge[:2]
        mergedattrs = edgeattrs.setdefault(key, dict())
        # mergedattrs["links"] and attrs["links"]
        # could be two sequences that need to be concatenated
        dict_merge(mergedattrs, attrs, contatenate_sequences=True)

    for name, attrs in graph.nodes.items():
        newgraph.add_node(name, **attrs)
    for (source, target), mergedattrs in edgeattrs.items():
        newgraph.add_edge(source, target, **mergedattrs)
    return newgraph
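
# Example (sketch; assumes `dict_merge` concatenates the "links" sequences as
# the comment above indicates):
#
#     mg = networkx.MultiDiGraph()
#     mg.add_edge("a", "b", links=[{"x": 1}])
#     mg.add_edge("a", "b", links=[{"x": 2}])
#     flat = flatten_multigraph(mg)
#     flat["a"]["b"]["links"]  # -> [{"x": 1}, {"x": 2}]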


def get_subgraphs(graph: networkx.DiGraph, **load_options):
    subgraphs = dict()
    for node_id, node_attrs in graph.nodes.items():
        task_type, task_info = inittask.task_executable_info(
            node_attrs, node_id=node_id, all=True
        )
        if task_type == "graph":
            g = load_graph(task_info["task_identifier"], **load_options)
            g.graph.graph["id"] = node_id
            node_label = node_attrs.get("label")
            if node_label:
                g.graph.graph["label"] = node_label
            subgraphs[node_id] = g
    return subgraphs


def _ewoks_jsonload_hook_pair(item):
    key, value = item
    if key in ("source", "target", "sub_source", "sub_target", "id", "sub_node"):
        value = node_id_from_json(value)
    return key, value


def ewoks_jsonload_hook(items):
    return dict(map(_ewoks_jsonload_hook_pair, items))
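
# Example (sketch): node identifiers serialized as JSON arrays are restored
# by `node_id_from_json` for the keys handled above (assuming it converts
# lists to tuples):
#
#     json.loads(
#         '{"source": ["sub", "a"], "target": "b"}',
#         object_pairs_hook=ewoks_jsonload_hook,
#     )
#     # -> {"source": ("sub", "a"), "target": "b"}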


def abs_path(path, root_dir=None):
    if os.path.isabs(path):
        return path
    if root_dir:
        path = os.path.join(root_dir, path)
    return os.path.abspath(path)
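
# Example (on POSIX):
#
#     abs_path("workflow.json", root_dir="/data")  # -> "/data/workflow.json"
#     abs_path("/tmp/workflow.json", root_dir="/data")  # -> "/tmp/workflow.json"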


class TaskGraph:
    """The API for graph analysis is provided by `networkx`.
    Any directed graph is supported (cyclic or acyclic).

    Loop over the dependencies of a task

    .. code-block:: python

        for source in taskgraph.predecessors(target):
            link_attrs = taskgraph.graph[source][target]

    Loop over the tasks dependent on a task

    .. code-block:: python

        for target in taskgraph.successors(source):
            link_attrs = taskgraph.graph[source][target]

    Instantiate a task

    .. code-block:: python

        task = taskgraph.instantiate_task(name, varinfo=varinfo, inputs=inputs)

    For acyclic graphs, sequential task execution can be done like this:

    .. code-block:: python

        taskgraph.execute()
    """

    GraphRepresentation = enum.Enum(
        "GraphRepresentation", "json_file json_dict json_string yaml"
    )

    def __init__(self, source=None, representation=None, **load_options):
        self.load(source=source, representation=representation, **load_options)

    def __repr__(self):
        return self.graph_label

    @property
    def graph_id(self):
        return self.graph.graph.get("id", qualname(type(self)))

    @property
    def graph_label(self):
        return self.graph.graph.get("label", self.graph_id)

    def __eq__(self, other):
        if not isinstance(other, type(self)):
            raise TypeError(other, type(other))
        return self.dump() == other.dump()

    def load(self, source=None, representation=None, root_dir=None):
        """From persistent to runtime representation"""
        if representation is None:
            if isinstance(source, Mapping):
                representation = self.GraphRepresentation.json_dict
            elif isinstance(source, str):
                if source.endswith(".json"):
                    representation = self.GraphRepresentation.json_file
                else:
                    representation = self.GraphRepresentation.json_string
        if not source:
            graph = networkx.DiGraph()
        elif isinstance(source, networkx.Graph):
            graph = source
        elif isinstance(source, TaskGraph):
            graph = source.graph
        elif representation == self.GraphRepresentation.json_dict:
            set_graph_defaults(source)
            graph = networkx.readwrite.json_graph.node_link_graph(source)
        elif representation == self.GraphRepresentation.json_file:
            source = abs_path(source, root_dir)
            with open(source, mode="r") as f:
                source = json.load(f, object_pairs_hook=ewoks_jsonload_hook)
            set_graph_defaults(source)
            graph = networkx.readwrite.json_graph.node_link_graph(source)
        elif representation == self.GraphRepresentation.json_string:
            source = json.loads(source, object_pairs_hook=ewoks_jsonload_hook)
            set_graph_defaults(source)
            graph = networkx.readwrite.json_graph.node_link_graph(source)
        elif representation == self.GraphRepresentation.yaml:
            source = abs_path(source, root_dir)
            graph = networkx.readwrite.read_yaml(source)
        else:
            raise TypeError(representation, type(representation))

        if not networkx.is_directed(graph):
            raise TypeError(graph, type(graph))

        subgraphs = get_subgraphs(graph, root_dir=root_dir)
        if subgraphs:
            # Extract
            edges, update_attrs = extract_graph_nodes(graph, subgraphs)
            graph = flatten_multigraph(graph)

            # Merge
            self.graph = graph
            graphs = [self] + list(subgraphs.values())
            rename_nodes = [False] + [True] * len(subgraphs)
            graph = merge_graphs(
                graphs,
                graph_attrs=graph.graph,
                rename_nodes=rename_nodes,
                root_dir=root_dir,
            ).graph

            # Re-link
            add_subgraph_links(graph, edges, update_attrs)

        self.graph = flatten_multigraph(graph)
        self.validate_graph()

    def dump(self, destination=None, representation=None, **kw):
        """From runtime to persistent representation"""
        if representation is None:
            if isinstance(destination, str) and destination.endswith(".json"):
                representation = self.GraphRepresentation.json_file
            else:
                representation = self.GraphRepresentation.json_dict
        if representation == self.GraphRepresentation.json_dict:
            return networkx.readwrite.json_graph.node_link_data(self.graph)
        elif representation == self.GraphRepresentation.json_file:
            dictrepr = self.dump()
            with open(destination, mode="w") as f:
                json.dump(dictrepr, f, **kw)
            return destination
        elif representation == self.GraphRepresentation.json_string:
            dictrepr = self.dump()
            return json.dumps(dictrepr, **kw)
        elif representation == self.GraphRepresentation.yaml:
            return networkx.readwrite.write_yaml(self.graph, destination, **kw)
        else:
            raise TypeError(representation, type(representation))

    def serialize(self):
        return self.dump(representation=self.GraphRepresentation.json_string)
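
    # Example (sketch; "workflow.json" is a hypothetical path):
    #
    #     taskgraph.dump("workflow.json", indent=2)
    #     assert load_graph("workflow.json") == taskgraph  # equality compares dumps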

    @property
    def is_cyclic(self):
        return not networkx.is_directed_acyclic_graph(self.graph)

    @property
    def has_conditional_links(self):
        for attrs in self.graph.edges.values():
            if attrs.get("conditions") or attrs.get("on_error"):
                return True
        return False

    def instantiate_task(
        self,
        node_id: NodeIdType,
        varinfo: Optional[dict] = None,
        inputs: Optional[dict] = None,
    ) -> Task:
        """Named arguments are dynamic inputs and Variable configuration.
        Default inputs from the persistent representation are added internally.
        """
        # Dynamic inputs have priority over default inputs
        nodeattrs = self.graph.nodes[node_id]
        return inittask.instantiate_task(
            nodeattrs, node_id=node_id, varinfo=varinfo, inputs=inputs
        )

    def instantiate_task_static(
        self,
        node_id: NodeIdType,
        tasks: Optional[dict] = None,
        varinfo: Optional[dict] = None,
    ) -> Task:
        """Instantiate the destination task without access to the dynamic inputs.
        Side effect: `tasks` will contain all predecessors.
        """
        if self.is_cyclic:
            raise RuntimeError(f"{self} is cyclic")
        if tasks is None:
            tasks = dict()
        dynamic_inputs = dict()
        for inputnode in self.predecessors(node_id):
            inputtask = tasks.get(inputnode, None)
            if inputtask is None:
                inputtask = self.instantiate_task_static(
                    inputnode, tasks=tasks, varinfo=varinfo
                )
            link_attrs = self.graph[inputnode][node_id]
            inittask.add_dynamic_inputs(
                dynamic_inputs, link_attrs, inputtask.output_variables
            )
        task = self.instantiate_task(node_id, varinfo=varinfo, inputs=dynamic_inputs)
        tasks[node_id] = task
        return task
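
    # Example (sketch; assumes an acyclic graph with importable tasks):
    #
    #     tasks = dict()
    #     task = taskgraph.instantiate_task_static(node_id, tasks=tasks)
    #     # `tasks` now also holds every instantiated predecessor of node_id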

    def successors(self, node_id: NodeIdType, **include_filter):
        yield from self._iter_downstream_nodes(
            node_id, recursive=False, **include_filter
        )

    def descendants(self, node_id: NodeIdType, **include_filter):
        yield from self._iter_downstream_nodes(
            node_id, recursive=True, **include_filter
        )

    def predecessors(self, node_id: NodeIdType, **include_filter):
        yield from self._iter_upstream_nodes(node_id, recursive=False, **include_filter)

    def ancestors(self, node_id: NodeIdType, **include_filter):
        yield from self._iter_upstream_nodes(node_id, recursive=True, **include_filter)

    def has_successors(self, node_id: NodeIdType, **include_filter):
        return self._iterator_has_items(self.successors(node_id, **include_filter))

    def has_descendants(self, node_id: NodeIdType, **include_filter):
        return self._iterator_has_items(self.descendants(node_id, **include_filter))

    def has_predecessors(self, node_id: NodeIdType, **include_filter):
        return self._iterator_has_items(self.predecessors(node_id, **include_filter))

    def has_ancestors(self, node_id: NodeIdType, **include_filter):
        return self._iterator_has_items(self.ancestors(node_id, **include_filter))
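
    # Example (sketch): `include_filter` keywords combine node and link filters
    # with a logical AND. For instance, iterate over conditional successors:
    #
    #     for target in taskgraph.successors(node_id, link_has_conditions=True):
    #         ...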

    @staticmethod
    def _iterator_has_items(iterator):
        try:
            next(iterator)
            return True
        except StopIteration:
            return False

    def _iter_downstream_nodes(self, node_id: NodeIdType, **kw):
        yield from self._iter_nodes(node_id, upstream=False, **kw)

    def _iter_upstream_nodes(self, node_id: NodeIdType, **kw):
        yield from self._iter_nodes(node_id, upstream=True, **kw)

    def _iter_nodes(
        self,
        node_id: NodeIdType,
        upstream=False,
        recursive=False,
        _visited=None,
        **include_filter,
    ):
        """Recursion is not stopped by the node or link filters.
        Recursion is stopped by either not having any successors/predecessors
        or encountering a node that has been visited already.
        The original node on which we start iterating is never included.
        """
        if recursive:
            if _visited is None:
                _visited = set()
            elif node_id in _visited:
                return
            _visited.add(node_id)

        if upstream:
            iter_next_nodes = self.graph.predecessors
        else:
            iter_next_nodes = self.graph.successors

        for next_id in iter_next_nodes(node_id):
            node_is_included = self._filter_node(next_id, **include_filter)
            if upstream:
                link_is_included = self._filter_link(next_id, node_id, **include_filter)
            else:
                link_is_included = self._filter_link(node_id, next_id, **include_filter)
            if node_is_included and link_is_included:
                yield next_id
            if recursive:
                yield from self._iter_nodes(
                    next_id,
                    upstream=upstream,
                    recursive=True,
                    _visited=_visited,
                    **include_filter,
                )

    def _filter_node(
        self,
        node_id: NodeIdType,
        node_filter=None,
        node_has_predecessors=None,
        node_has_successors=None,
        **linkfilter,
    ):
        """Filters are combined with the logical AND"""
        if callable(node_filter):
            if not node_filter(node_id):
                return False
        if node_has_predecessors is not None:
            if self.has_predecessors(node_id) != node_has_predecessors:
                return False
        if node_has_successors is not None:
            if self.has_successors(node_id) != node_has_successors:
                return False
        return True

    def _filter_link(
        self,
        source_id: NodeIdType,
        target_id: NodeIdType,
        link_filter=None,
        link_has_on_error=None,
        link_has_conditions=None,
        link_is_conditional=None,
        link_has_required=None,
        **nodefilter,
    ):
        """Filters are combined with the logical AND"""
        if callable(link_filter):
            if not link_filter(source_id, target_id):
                return False
        if link_has_on_error is not None:
            if self._link_has_on_error(source_id, target_id) != link_has_on_error:
                return False
        if link_has_conditions is not None:
            if self._link_has_conditions(source_id, target_id) != link_has_conditions:
                return False
        if link_is_conditional is not None:
            if self._link_is_conditional(source_id, target_id) != link_is_conditional:
                return False
        if link_has_required is not None:
            if self._link_has_required(source_id, target_id) != link_has_required:
                return False
        return True

    def _link_has_conditions(self, source_id: NodeIdType, target_id: NodeIdType):
        link_attrs = self.graph[source_id][target_id]
        return bool(link_attrs.get("conditions", False))

    def _link_has_on_error(self, source_id: NodeIdType, target_id: NodeIdType):
        link_attrs = self.graph[source_id][target_id]
        return bool(link_attrs.get("on_error", False))

    def _link_has_required(self, source_id: NodeIdType, target_id: NodeIdType):
        link_attrs = self.graph[source_id][target_id]
        return bool(link_attrs.get("required", False))

    def _link_is_conditional(self, source_id: NodeIdType, target_id: NodeIdType):
        link_attrs = self.graph[source_id][target_id]
        return bool(
            link_attrs.get("on_error", False) or link_attrs.get("conditions", False)
        )

    def link_is_required(self, source_id: NodeIdType, target_id: NodeIdType):
        if self._link_has_required(source_id, target_id):
            return True
        if self._link_is_conditional(source_id, target_id):
            return False
        return self._node_is_required(source_id)

    def _node_is_required(self, node_id: NodeIdType):
        return not self.has_ancestors(
            node_id, link_has_required=False, link_is_conditional=True
        )

    def _required_predecessors(self, target_id: NodeIdType):
        for source_id in self.predecessors(target_id):
            if self.link_is_required(source_id, target_id):
                yield source_id

    def _has_required_predecessors(self, node_id: NodeIdType):
        return self._iterator_has_items(self._required_predecessors(node_id))

    def _has_required_static_inputs(self, node_id: NodeIdType):
        """Returns True when the default inputs cover all required inputs."""
        node_attrs = self.graph.nodes[node_id]
        inputs_complete = node_attrs.get("inputs_complete", None)
        if isinstance(inputs_complete, bool):
            # Method and script tasks always have an empty `required_input_names`
            # although they may have required inputs. This keyword is used to
            # manually indicate that all required inputs are statically provided.
            return inputs_complete
        taskclass = inittask.get_task_class(node_attrs, node_id=node_id)
        static_inputs = {d["name"] for d in node_attrs.get("default_inputs", list())}
        return not (set(taskclass.required_input_names()) - static_inputs)

    def start_nodes(self) -> Set[NodeIdType]:
        nodes = set(
            node_id
            for node_id in self.graph.nodes
            if not self.has_predecessors(node_id)
        )
        if nodes:
            return nodes
        return set(
            node_id
            for node_id in self.graph.nodes
            if self._has_required_static_inputs(node_id)
            and not self._has_required_predecessors(node_id)
        )

    def end_nodes(self) -> Set[NodeIdType]:
        nodes = set(
            node_id for node_id in self.graph.nodes if not self.has_successors(node_id)
        )
        if nodes:
            return nodes
        return set(
            node_id
            for node_id in self.graph.nodes
            if self._node_has_noncovered_conditions(node_id)
        )
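
    # Example (sketch; a linear workflow a -> b -> c without conditional links):
    #
    #     taskgraph.start_nodes()  # -> {"a"}
    #     taskgraph.end_nodes()    # -> {"c"}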

    def _node_has_noncovered_conditions(self, source_id: NodeIdType) -> bool:
        links = self._get_node_expanded_conditions(source_id)
        # Flatten each link's condition list into a mapping of output name
        # to expected value so conditions can be compared pairwise.
        links = [
            {item["source_output"]: item["value"] for item in conditions}
            for conditions in links
        ]
        has_complement = [False] * len(links)

        default_complements = {CONDITIONS_ELSE_VALUE}
        complements = {
            CONDITIONS_ELSE_VALUE: None,
            True: {False, CONDITIONS_ELSE_VALUE},
            False: {True, CONDITIONS_ELSE_VALUE},
        }

        for i, conditions1 in enumerate(links):
            if has_complement[i]:
                continue
            for j in range(i + 1, len(links)):
                conditions2 = links[j]
                if self._conditions_are_complementary(
                    conditions1, conditions2, default_complements, complements
                ):
                    has_complement[i] = True
                    has_complement[j] = True
                    break
            if not has_complement[i]:
                return True
        return False

    @staticmethod
    def _conditions_are_complementary(
        conditions1, conditions2, default_complements, complements
    ):
        for varname, value in conditions1.items():
            complementary_values = complements.get(value, default_complements)
            if complementary_values is None:
                # Any value is complementary
                continue
            if conditions2[varname] not in complementary_values:
                return False
        return True
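
    # Example (sketch): {"result": True} and {"result": False} are complementary
    # because False is among the complements of True, so a node whose two
    # conditional links test these values has full condition coverage:
    #
    #     complements = {
    #         CONDITIONS_ELSE_VALUE: None,
    #         True: {False, CONDITIONS_ELSE_VALUE},
    #         False: {True, CONDITIONS_ELSE_VALUE},
    #     }
    #     TaskGraph._conditions_are_complementary(
    #         {"result": True}, {"result": False}, {CONDITIONS_ELSE_VALUE}, complements
    #     )  # -> True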

    def _get_node_expanded_conditions(self, source_id: NodeIdType):
        """Ensure that all conditional links starting from a source node
        have the same set of output names.
        """
        links = [
            self.graph[source_id][target_id]["conditions"]
            for target_id in self.successors(source_id, link_has_conditions=True)
        ]
        all_condition_names = {
            item["source_output"] for conditions in links for item in conditions
        }
        for conditions in links:
            link_condition_names = {item["source_output"] for item in conditions}
            for name in all_condition_names - link_condition_names:
                conditions.append(
                    {"source_output": name, "value": CONDITIONS_ELSE_VALUE}
                )
        return links

    def validate_graph(self):
        for node_id, node_attrs in self.graph.nodes.items():
            inittask.validate_task_executable(node_attrs, node_id=node_id)

            # Isolated nodes do no harm, so this check is disabled:
            # if len(graph.nodes) > 1 and not node_has_links(graph, node_id):
            #    raise ValueError(f"Node {repr(node_id)} has no links")

            inputs_from_required = dict()
            for source_id in self._required_predecessors(node_id):
                link_attrs = self.graph[source_id][node_id]
                arguments = link_attrs.get("data_mapping", list())
                for arg in arguments:
                    try:
                        name = arg["target_input"]
                    except KeyError:
                        raise KeyError(
                            f"Argument '{arg}' of link '{source_id}' -> '{node_id}' is missing a 'target_input' key"
                        ) from None
                    other_source_id = inputs_from_required.get(name)
                    if other_source_id:
                        raise ValueError(
                            f"Node {repr(source_id)} and {repr(other_source_id)} both connect to the input {repr(name)} of {repr(node_id)}"
                        )
                    inputs_from_required[name] = source_id

        for (source, target), linkattrs in self.graph.edges.items():
            err_msg = (
                f"Link {source}->{target}: '{{}}' and '{{}}' cannot be used together"
            )
            if linkattrs.get("map_all_data") and linkattrs.get("data_mapping"):
                raise ValueError(err_msg.format("map_all_data", "data_mapping"))
            if linkattrs.get("on_error") and linkattrs.get("conditions"):
                raise ValueError(err_msg.format("on_error", "conditions"))

    def topological_sort(self):
        """Sort node names for sequential instantiation+execution of DAGs"""
        if self.is_cyclic:
            raise RuntimeError("Sorting nodes is not possible for cyclic graphs")
        yield from networkx.topological_sort(self.graph)

    def execute(self, varinfo: Optional[dict] = None, raise_on_error: bool = True):
        """Sequential execution of DAGs"""
        if self.is_cyclic:
            raise RuntimeError("Cannot execute cyclic graphs")
        if self.has_conditional_links:
            raise RuntimeError("Cannot execute graphs with conditional links")
        tasks = dict()
        for node in self.topological_sort():
            task = self.instantiate_task_static(node, tasks=tasks, varinfo=varinfo)
            task.execute(raise_on_error=raise_on_error)
        return tasks
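
    # Example (sketch; `workflow` is a hypothetical acyclic, condition-free
    # workflow source with importable tasks):
    #
    #     tasks = load_graph(workflow).execute()
    #     # `tasks` maps every node id to its executed Task instance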