# utils.py: assertion and plotting helpers for ewoks task graphs and workflow results
from typing import Any, Dict, Optional
2
3
4
5
import networkx
from pprint import pprint
import matplotlib.pyplot as plt
from ewokscore import load_graph
6
7
8
from ewokscore.graph import TaskGraph
from ewokscore.node import NodeIdType
from ewokscore.task import Task
9
from ewokscore.variable import value_from_transfer
10
11


12
13
14
15
16
17
def assert_taskgraph_result(
    taskgraph: TaskGraph,
    expected: Dict[NodeIdType, Any],
    varinfo: Optional[dict] = None,
    tasks: Optional[Dict[NodeIdType, Task]] = None,
):
    """Check the result of every node of an acyclic task graph.

    :param taskgraph: anything accepted by ``load_graph``; it is passed
        through ``load_graph`` before use.
    :param expected: expected output values per node id. A node that is
        absent or maps to ``None`` is expected not to be done
        (see ``assert_task_result``).
    :param varinfo: forwarded to ``instantiate_task_static`` for nodes that
        have no entry in ``tasks``; required (truthy) in that case.
    :param tasks: already-instantiated tasks per node id. When given, this
        same dict is passed as ``tasks=`` to ``instantiate_task_static``.
    :raises AssertionError: when the graph is cyclic, when ``varinfo`` is
        needed but missing, or when a node's result does not match.
    """
    loaded = load_graph(taskgraph)
    assert not loaded.is_cyclic, "Can only check DAG results"

    known_tasks = {} if tasks is None else tasks
    for node_id in loaded.graph.nodes:
        task = known_tasks.get(node_id, None)
        if task is None:
            # No pre-built task for this node: re-instantiate it so that its
            # output can be loaded, which requires varinfo.
            assert varinfo, "Need 'varinfo' to load task output"
            task = loaded.instantiate_task_static(
                node_id, tasks=known_tasks, varinfo=varinfo
            )
        assert_task_result(task, node_id, expected)


32
33
def assert_task_result(task: Task, node_id: NodeIdType, expected: dict):
    """Check that ``task`` is done with the expected outputs, or is not done.

    :param task: task whose ``done`` state and ``output_values`` are checked.
    :param node_id: key of the task in ``expected``; also used as the
        assertion message.
    :param expected: expected ``output_values`` per node id. When the node
        is missing or maps to ``None``, the task must not be done.
    :raises AssertionError: when the done state or the output values differ
        from what is expected.
    :raises RuntimeError: when the task is done but retrieving its
        ``output_values`` raises a non-assertion exception.
    """
    expected_value = expected.get(node_id)
    if expected_value is None:
        assert not task.done, node_id
        return

    assert task.done, node_id
    # Guard only the retrieval of the outputs. An exception raised by the
    # comparison itself (e.g. an ambiguous array truth value) must not be
    # misreported as a missing result.
    try:
        output_values = task.output_values
    except AssertionError:
        raise
    except Exception as e:
        raise RuntimeError(f"{node_id} does not have a result") from e
    assert output_values == expected_value, node_id


46
47
48
49
50
def assert_workflow_result(
    results: Dict[NodeIdType, Any],
    expected: Dict[NodeIdType, Any],
    varinfo: Optional[dict] = None,
):
    """Check per-node workflow results against expected output values.

    Only the nodes listed in ``expected`` are checked.

    :param results: result per node id. A result is either a ``Task`` (its
        ``output_values`` are used) or something indexable by output name.
    :param expected: per node id, a mapping from output name to expected
        value, or ``None`` when the node must not appear in ``results``.
    :param varinfo: forwarded to ``assert_result`` for every value.
    :raises AssertionError: when a node that must be absent is present,
        when a ``Task`` result is not done, or when a value differs.
    :raises KeyError: when an expected node or output is missing.
    """
    for node_id, expected_result in expected.items():
        if expected_result is None:
            # Name the node in the failure, like the other assertions here.
            assert node_id not in results, node_id
            continue
        result = results[node_id]
        if isinstance(result, Task):
            assert result.done, node_id
            result = result.output_values
        for output_name, expected_value in expected_result.items():
            value = result[output_name]
            assert_result(value, expected_value, varinfo=varinfo)


64
65
66
def assert_workflow_merged_result(
    result: dict, expected: Dict[NodeIdType, Any], varinfo: Optional[dict] = None
):
    """Check a single merged workflow result against expected values.

    Every key of ``expected`` must be present in ``result``; extra keys in
    ``result`` are not checked.

    :param result: merged outputs, indexed by the keys of ``expected``.
    :param expected: expected value per output name.
    :param varinfo: forwarded to ``assert_result`` for every value.
    :raises KeyError: when an expected output is missing from ``result``.
    :raises AssertionError: when a value differs from its expectation.
    """
    for name, wanted in expected.items():
        assert_result(result[name], wanted, varinfo=varinfo)


72
def assert_result(value, expected_value, varinfo: Optional[dict] = None):
    """Assert that ``value`` equals ``expected_value`` after resolving it.

    :param value: value to check; it is first passed through
        ``value_from_transfer`` with ``varinfo``.
    :param expected_value: value to compare against.
    :param varinfo: forwarded to ``value_from_transfer``.
    :raises AssertionError: when the resolved value differs; the message
        shows both values.
    """
    value = value_from_transfer(value, varinfo=varinfo)
    # By default pytest only rewrites asserts in test modules and conftest
    # files, so a helper module would otherwise raise a bare AssertionError.
    # The message is only built when the assertion fails.
    assert value == expected_value, f"{value!r} != {expected_value!r}"
75
76
77
78
79
80
81
82
83
84


def show_graph(graph, stdout=True, plot=True, show=True):
    """Print and/or draw a task graph.

    :param graph: anything accepted by ``load_graph``.
    :param stdout: pretty-print the result of ``dump()`` on the loaded graph.
    :param plot: draw the loaded graph's ``graph`` attribute with networkx,
        with node labels.
    :param show: call ``plt.show()`` after drawing; ignored when ``plot``
        is false.
    """
    loaded = load_graph(graph)
    if stdout:
        pprint(loaded.dump())
    if not plot:
        return
    networkx.draw(loaded.graph, with_labels=True, font_size=10)
    if show:
        plt.show()