node.py 13.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# coding: utf-8
# /*##########################################################################
#
# Copyright (c) 2017 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# ###########################################################################*/

__authors__ = ["H.Payno"]
__license__ = "MIT"
__date__ = "29/05/2017"

import functools
import logging
import traceback
import inspect
34
import importlib
payno's avatar
payno committed
35
from typing import Union
36
from importlib.machinery import SourceFileLoader
37

38
_logger = logging.getLogger(__name__)

# Module-level counter holding the next id returned by
# `get_next_node_free_id`. It is read and incremented there through a
# function-level `global` statement; no `global` declaration is needed (or has
# any effect) at module scope.
next_node_free_id = 0


payno's avatar
payno committed
44
def get_next_node_free_id() -> int:
    """
    Hand out the current value of the module-level `next_node_free_id`
    counter and advance the counter by one.

    :return: the value of the counter before it was incremented
    :rtype: int
    """
    global next_node_free_id
    allocated = next_node_free_id
    next_node_free_id = allocated + 1
    return allocated


def trace_unhandled_exceptions(func):
    """
    Decorator turning any `Exception` raised by `func` into a *returned*
    `WorkflowException` instead of letting it propagate.

    On failure the exception is logged through the module logger and a
    `WorkflowException` is built from the exception message, the formatted
    traceback and the second positional argument of the call.

    :param func: callable to wrap
    :return: wrapped callable returning either the result of `func` or a
             `WorkflowException` instance
    """
    @functools.wraps(func)
    def wrapped_func(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            _logger.exception(e)
            # The second positional argument is forwarded as the exception
            # payload (presumably the input data of a `(self, data)` style
            # call -- confirm against callers). Guard the lookup so a call
            # made with fewer positional arguments does not raise an
            # IndexError from inside this handler.
            data = args[1] if len(args) > 1 else None
            return WorkflowException(
                msg='{0}'.format(e),
                traceBack=traceback.format_exc(),
                data=data,
            )

    return wrapped_func


class Node(object):
    """
    Node in the `.Scheme`. Will be associated to a core process.

    :param processing_pt: callable, or str of the form 'module.name' (or
                          'path/to/file.py.name') pointing to a class or a
                          function. It is resolved by `load_handlers`.
    :param id: unique id of the node. If None, the id is taken from
               `get_next_node_free_id`.
    :param properties: properties of the node. None (or any empty value) is
                       replaced by a new empty dict.
    :param error_handler: optional object providing a `to_json()` method; it
                          is serialized by `to_json` under the
                          'error_handler' key.
    """

    need_stop_join = False
    """flag to stop the node only when receive the 'stop' signal"""

    # keys used by the json (dict) description of a node
    _JSON_PROCESS_PT = 'process_pt'
    _JSON_ID = 'id'
    _JSON_PROPERTIES = 'properties'
    _JSON_ERROR_HANDLER = 'error_handler'

    def __init__(self, processing_pt, id: Union[None, int] = None,
                 properties: Union[None, dict] = None,
                 error_handler=None):
        self.id = get_next_node_free_id() if id is None else id
        """int of the node id"""
        self.properties = properties or {}
        """dict of the node properties"""
        self.upstream_nodes = set()
        """Set of upstream nodes"""
        self.downstream_nodes = set()
        """Set of downstream nodes"""
        self.__process_instance = None
        """process instance (callable or class instance), set by
        `load_handlers`"""
        self._process_pt = processing_pt
        """pointer to the process: callable or str path"""
        self._handlers = {}
        """handlers with link name as key and callback as value.
        The default handler is store under the 'None' value"""
        self._input_type_to_name = {}
        """link input type to a signal name"""
        self._output_type_to_name = {}
        """link output type to a signal name"""
        self._error_handler = error_handler
        self.out_data = None

    def get_input_channel_name(self, data_object):
        """
        :param data_object: object to find an input channel for
        :return: name of the first registered input channel whose type
                 matches `data_object` (isinstance check), None if no
                 registered type matches.
        """
        for dtype, channel_name in self._input_type_to_name.items():
            if isinstance(data_object, dtype):
                return channel_name
        return None

    def get_output_channel_name(self, data_object):
        """
        :param data_object: object to find an output channel for
        :return: name of the first registered output channel whose type
                 matches `data_object` (isinstance check), None if no
                 registered type matches.
        """
        for dtype, channel_name in self._output_type_to_name.items():
            if isinstance(data_object, dtype):
                return channel_name
        return None

    @property
    def handlers(self) -> dict:
        """dict of handlers; the default handler is stored under None"""
        return self._handlers

    @property
    def process_pt(self):
        """pointer to the process as given to the constructor"""
        return self._process_pt

    @property
    def class_instance(self):
        """process instance created by `load_handlers` (None before)"""
        return self.__process_instance

    def isfinal(self) -> bool:
        """

        :return: True if the node is at the end of a branch.
        :rtype: bool
        """
        # compare by value: `len(...) is 0` relied on int identity
        return not self.downstream_nodes

    def isstart(self) -> bool:
        """

        :return: True if the node does not requires any input
        :rtype: bool
        """
        return not self.upstream_nodes

    def load_handlers(self) -> None:
        """
        load handlers from the `processing_pt` defined.

        * callable `processing_pt`: the callable itself is registered as the
          default handler (key None).
        * str `processing_pt`: split on the last '.' into a module part and
          an attribute name. A module part ending with '.py' is loaded as a
          source file, otherwise it is imported. If the attribute is a class
          it is instantiated without argument. The result is registered as
          default handler if it is callable, and each entry of its `inputs`
          / `outputs` attributes (when present) is registered as
          (name, type, handler).

        :raises: ValueError if `processing_pt` is neither a callable nor a
                 valid str path, if several outputs share the same type, or
                 if unable to find some handlers in the classes definition
        """
        self._handlers.clear()
        self._input_type_to_name.clear()
        self._output_type_to_name.clear()
        assert self._process_pt is not None

        if callable(self._process_pt):
            self.__process_instance = self._process_pt
            self._handlers[None] = self._process_pt
        else:
            if type(self._process_pt) is not str:
                raise ValueError('process_pt should be a callable or path to a class or function')
            module_name, separator, class_name = self._process_pt.rpartition('.')
            if not separator:
                raise ValueError(self._process_pt + ' is not recognized as a valid name')

            if module_name.endswith('.py'):
                # warning: in this case the file should not have any relative
                # (presumably: relative import)
                # NOTE(review): `load_module()` is deprecated in the standard
                # library in favour of `exec_module()`; kept as is here.
                module = SourceFileLoader(module_name,
                                          module_name).load_module()
            else:
                module = importlib.import_module(module_name)

            class_or_fct = getattr(module, class_name)
            if inspect.isclass(class_or_fct):
                _logger.debug('instanciate ' + str(class_or_fct))
                instance = class_or_fct()
            else:
                instance = class_or_fct
            self.__process_instance = instance

            if callable(instance):
                self._handlers[None] = instance

            # manage the case where a class has several input handler
            if hasattr(instance, 'inputs'):
                for input_ in instance.inputs:
                    input_name, input_type, input_handler = input_[:3]
                    _logger.debug('[node: %s] add input_name: %s, '
                                  'input_type: %s, input_handler: %s',
                                  self._process_pt, input_name, input_type,
                                  input_handler)
                    # NOTE(review): keys are stored as the type object
                    # itself, so this str() lookup only matches when
                    # input_type is a str. Kept unchanged to avoid rejecting
                    # inputs that are currently accepted -- confirm whether
                    # duplicated input types should really raise.
                    if str(input_type) in self._input_type_to_name:
                        raise ValueError('Several input name found for the '
                                         'same input type. This case is not managed.')
                    self._input_type_to_name[input_type] = input_name
                    # the handler is stored as declared in `inputs`; it is
                    # resolved on the process instance at execution time
                    self._handlers[input_name] = input_handler

            if hasattr(instance, 'outputs'):
                for output_ in instance.outputs:
                    output_name, output_type, output_handler = output_[:3]
                    # log the *output* description (the input variables may
                    # not even exist if the process declares no input)
                    _logger.debug('[node: %s] add output_name: %s, '
                                  'output_type: %s, output_handler: %s',
                                  self._process_pt, output_name, output_type,
                                  output_handler)
                    if output_type in self._output_type_to_name:
                        raise ValueError(
                            'Several output name found for the '
                            'same output type. This case is not managed.')
                    self._output_type_to_name[output_type] = output_name

        if len(self._handlers) == 0:
            raise ValueError('Fail to init handlers, none defined for ' + str(self._process_pt))

    @staticmethod
    def execute(process_pt, properties: dict, input_name: str,
                input_data: object) -> tuple:
        """
        Create an instance of a node with 'process_pt' and execute it with the
        given input_name, properties and input_data.

        :param str process_pt: name of the process to execute
         (can be a module.class name, or a module.function)
        :param dict properties: process properties
        :param str input_name: name of the input data
        :param input_data: input data :warning: Should be serializable

        :return: (output channel name, output data). If the output has a
                 `to_dict` method, the result of `to_dict()` is returned
                 instead of the output itself.
                 :warning: Should be serializable
        :raises ValueError: if the process has no `set_properties` function
        :raises KeyError: if neither `input_name` nor a default handler is
                          registered
        """
        node = Node(processing_pt=process_pt, properties=properties)
        node.load_handlers()
        _logger.info('start execution of %s with %s through channel %s',
                     process_pt, input_data, input_name)

        instance = node.class_instance
        if not hasattr(instance, 'set_properties'):
            raise ValueError('no function set properties found')
        instance.set_properties(properties)

        if input_name in node.handlers:
            handler = node.handlers[input_name]
        elif None in node.handlers:
            handler = node.handlers[None]
        else:
            err = '"{0}" channel is not managed by {1}'.format(input_name, node._process_pt)
            raise KeyError(err)

        # The default handler is stored as a callable whereas handlers coming
        # from `inputs` may be method names: only resolve by name when the
        # handler is not directly callable.
        if not callable(handler):
            handler = getattr(instance, handler)
        out = handler(input_data)

        # retrieve output channel
        output_channel = None if out is None else node.get_output_channel_name(out)

        if hasattr(out, 'to_dict'):
            return output_channel, out.to_dict()
        return output_channel, out

    def to_json(self) -> dict:
        """

        :return: json description of the node
        :rtype: dict
        """
        res = {
            self._JSON_PROCESS_PT: self.process_pt,
            self._JSON_ID: self.id,
            self._JSON_PROPERTIES: self.properties,
        }
        res.update(self._get_error_handler_json())
        return res

    def _get_error_handler_json(self):
        """
        :return: dict with the json description of the error handler under
                 the 'error_handler' key (empty dict if no error handler)
        :rtype: dict
        """
        error_handler_json = self._error_handler.to_json() if self._error_handler else {}
        return {
            self._JSON_ERROR_HANDLER: error_handler_json,
        }

    @staticmethod
    def load_node_info_from_json(json_data: dict) -> tuple:
        """
        load from json stream the Node Information. Each missing entry is
        logged as an error and returned as None.

        :param json_data: node description

        :return: node id, properties, pointer to the process to run
        :rtype: tuple
        """
        # load properties
        if Node._JSON_PROPERTIES not in json_data:
            _logger.error('Missing node properties in json description')
            _properties = None
        else:
            _properties = json_data[Node._JSON_PROPERTIES]
            assert type(_properties) is dict
        # load id
        if Node._JSON_ID not in json_data:
            _logger.error('Missing node id in json description')
            _id = None
        else:
            _id = json_data[Node._JSON_ID]
            assert type(_id) is int
        # load process_pt
        if Node._JSON_PROCESS_PT not in json_data:
            _logger.error('Missing node process_pt in json description')
            _process_pt = None
        else:
            _process_pt = json_data[Node._JSON_PROCESS_PT]
        return _id, _properties, _process_pt

    @staticmethod
    def from_json(json_data: dict):
        """

        :param json_data: node description
        :return: New node created from the json description
        :rtype: Node
        :raise ValueError: if properties or id or processing_pt missing
        """
        _id, _properties, _process_pt = Node.load_node_info_from_json(json_data)
        if _properties is None or _id is None or _process_pt is None:
            raise ValueError('Unable to create Node from json, core information '
                             'are missing')
        return Node(id=_id, properties=_properties,
                    processing_pt=_process_pt)

    def __str__(self):
        return "node %s - %s" % (self.id, self._process_pt)
340

341
342
343
344
345
346
347
348
349

class WorkflowException(Exception):
    """
    Exception carrying the error message, the formatted traceback and the
    data associated with a failure.

    :param str traceBack: formatted traceback of the original error
    :param data: data attached to the failure; None is replaced by a new
                 empty dict
    :param msg: error message, also given to the `Exception` constructor
    """

    def __init__(self, traceBack="", data=None, msg=None):
        super().__init__(msg)
        self.errorMessage = msg
        self.traceBack = traceBack
        # never expose None as payload: fall back to a fresh dict
        self.data = {} if data is None else data
Olof Svensson's avatar
Olof Svensson committed
350
351
352
353
354
355
356
357
358
359
360
361


class ErrorHandler(Node):
    """
    `Node` that is always created without an error handler of its own and
    that contributes no error handler section to its json description.

    :param processing_pt: forwarded to `Node`
    :param id: forwarded to `Node`
    :param properties: forwarded to `Node`
    """

    def __init__(self, processing_pt, id=None, properties=None):
        # the error_handler of an ErrorHandler is forced to None
        super().__init__(
            processing_pt=processing_pt,
            id=id,
            properties=properties,
            error_handler=None,
        )

    def _get_error_handler_json(self):
        """
        :return: empty dict (no error handler entry for this node type)
        :rtype: dict
        """
        return dict()