mback_norm.py 7.09 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# coding: utf-8
# /*##########################################################################
#
# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# ###########################################################################*/
"""Wrapper around the larch ``mback_norm`` normalization process."""

payno's avatar
payno committed
27
28
from est.core.types import Spectrum, XASObject
from est.core.process.process import Process
29
from est.core.process.process import _NexusDatasetDef
30
from larch.xafs.mback import mback_norm
31
from larch.xafs.pre_edge import preedge
32
33
34
import multiprocessing
import functools
import logging
payno's avatar
payno committed
35

36
37
38
39
40
41
42
_logger = logging.getLogger(__name__)

# Debug switch. When True, process_spectr_mback_norm additionally asserts that
# each spectrum it receives is a larch ``Group``, so ``Group`` must be imported.
_DEBUG = True

if _DEBUG:
    from larch.symboltable import Group


payno's avatar
payno committed
43
44
45
46
47
48
49
50
def process_spectr_mback_norm(
    spectrum,
    configuration,
    overwrite=True,
    callbacks=None,
    output=None,
    output_dict=None,
):
    """Apply the larch ``mback_norm`` normalization to a single spectrum.

    The spectrum must already carry ``norm`` and ``pre_edge`` attributes;
    otherwise an error is logged and ``(None, None)`` is returned.

    The options ``z``, ``edge``, ``e0``, ``pre1``, ``pre2``, ``norm1``,
    ``norm2``, ``nnorm`` and ``nvict`` are taken from the configuration when
    present, otherwise from the result of larch ``preedge`` recomputed on the
    spectrum's ``energy`` / ``mu``.

    :param spectrum: spectrum to process
    :type: :class:`.Spectrum`
    :param configuration: configuration of the mback normalization: either
                          the options themselves or a dict holding them under
                          the ``"mback_norm"`` key. ``None`` is treated as an
                          empty configuration.
    :type: dict
    :param overwrite: False if we want to return a new Spectrum instance
    :type: bool
    :param callbacks: callables (taking no argument) executed once the
                      normalization is done.
    :type: Optional[Iterable[callable]]
    :param output: list to store the result, needed for pool processing.
                   Only written when ``output_dict`` is provided as well.
    :type: multiprocessing.manager.list
    :param output_dict: key is input spectrum, value is index in the output
                        list.
    :type: dict
    :return: the configuration received and the processed spectrum, or
             ``(None, None)`` when a prerequisite is missing
    :rtype: tuple (configuration, spectrum)
    """
    assert isinstance(spectrum, Spectrum)
    # lazy %-args: the arrays are only stringified if debug logging is enabled
    _logger.debug("start mback_norm on spectrum (%s, %s)", spectrum.x, spectrum.y)

    for required_attr in ("norm", "pre_edge"):
        if not hasattr(spectrum, required_attr):
            _logger.error(
                "spectrum doesn't have %s. Maybe you need to compute "
                "pre_edge first? Unable to compute mback_norm.",
                required_attr,
            )
            return None, None

    _conf = configuration if configuration is not None else {}
    if "mback_norm" in _conf:
        _conf = _conf["mback_norm"]

    if _DEBUG is True:
        assert isinstance(spectrum, Group)
    if overwrite:
        _spectrum = spectrum
    else:
        _spectrum = Spectrum().load_frm_dict(spectrum.to_dict())
    # TODO: computing each time preedge should be avoidable
    pre_edge_details = preedge(_spectrum.energy, _spectrum.mu)

    # explicit configuration values take precedence over preedge results
    opts = {}
    for opt_name in (
        "z",
        "edge",
        "e0",
        "pre1",
        "pre2",
        "norm1",
        "norm2",
        "nnorm",
        "nvict",
    ):
        if opt_name in _conf:
            opts[opt_name] = _conf[opt_name]
        elif pre_edge_details is not None and opt_name in pre_edge_details:
            opts[opt_name] = pre_edge_details[opt_name]

    mback_norm(_spectrum, group=_spectrum, **opts)

    if callbacks:
        for callback in callbacks:
            callback()

    if output is not None and output_dict is not None:
        # pool mode: store the result in the slot reserved for this spectrum
        output[output_dict[spectrum]] = _spectrum

    return configuration, _spectrum


def larch_mback_norm(xas_obj):
    """Run the larch mback normalization on ``xas_obj``.

    Convenience wrapper: builds a :class:`Larch_mback_norm` process with
    ``xas_obj`` as its only input and runs it.

    :param xas_obj: object containing the configuration and spectra to process
    :type: Union[XASObject, dict]
    :return: the processed object, as returned by ``Larch_mback_norm.run``
    :rtype: XASObject
    """
    process_inputs = {"xas_obj": xas_obj}
    return Larch_mback_norm(inputs=process_inputs).run()
129

payno's avatar
payno committed
130

131
_USE_MULTIPROCESSING_POOL = False
payno's avatar
payno committed
132
# note: we cannot use multiprocessing pool with pypushflow for now.
133
134
135


class Larch_mback_norm(Process):
    """Process applying the larch ``mback_norm`` normalization to every
    spectrum of an :class:`XASObject`."""

    _INPUT_NAMES = {"xas_obj"}

    _OUTPUT_NAMES = {"xas_obj"}

    def __init__(self, varinfo=None, **inputs):
        Process.__init__(self, name="mback_norm", varinfo=varinfo, **inputs)

    def set_properties(self, properties):
        """Keep the settings stored under ``"_larchSettings"``, if any."""
        if "_larchSettings" in properties:
            self._settings = properties["_larchSettings"]

    def run(self):
        """Normalize every spectrum of the input ``xas_obj``.

        Settings stored by :meth:`set_properties` (if any) override the
        ``"mback_norm"`` entry of the object's configuration.

        :raises ValueError: if no ``xas_obj`` input was provided
        :return: the processed object (also stored, as a dict, in
                 ``self.outputs.xas_obj``)
        :rtype: XASObject
        """
        xas_obj = self.inputs.xas_obj
        if xas_obj is None:
            raise ValueError("xas_obj should be provided")
        _xas_obj = self.getXasObject(xas_obj=xas_obj)
        # _settings is only assigned by set_properties in this class; tolerate
        # it being absent instead of raising AttributeError.
        settings = getattr(self, "_settings", None)
        if settings:
            _xas_obj.configuration["mback_norm"] = settings

        self._advancement.reset(max_=_xas_obj.n_spectrum)
        self._advancement.startProcess()
        self._pool_process(xas_obj=_xas_obj)
        self._advancement.endProcess()

        data_keys = [
            _NexusDatasetDef("mback_mu"),
            _NexusDatasetDef("norm_mback"),
        ]
        if _xas_obj.n_spectrum > 0 and hasattr(_xas_obj.spectra[0], "edge_step"):
            # keep data_keys homogeneous: every entry is a _NexusDatasetDef
            data_keys.append(_NexusDatasetDef("norm"))
        self.register_process(_xas_obj, data_keys=data_keys)
        self.outputs.xas_obj = _xas_obj.to_dict()
        return _xas_obj

    def _pool_process(self, xas_obj):
        """Run ``process_spectr_mback_norm`` on every spectrum of ``xas_obj``.

        Spectra are updated in place. While ``_USE_MULTIPROCESSING_POOL`` is
        False they are processed sequentially in the current process.
        """
        assert isinstance(xas_obj, XASObject)
        if not _USE_MULTIPROCESSING_POOL:
            for spectrum in xas_obj.spectra:
                process_spectr_mback_norm(
                    spectrum=spectrum,
                    configuration=xas_obj.configuration,
                    callbacks=self.callbacks,
                    overwrite=True,
                )
            return

        # Pool mode: workers receive pickled copies of the spectra, so an
        # identity-keyed lookup table or callbacks run in a worker would only
        # touch those copies. Results are therefore taken from the returned
        # values (``imap`` preserves input order) and progress is advanced
        # here, in the parent process.
        partial_ = functools.partial(
            process_spectr_mback_norm,
            configuration=xas_obj.configuration,
            callbacks=None,
            overwrite=False,
        )
        results = []
        with multiprocessing.Pool(5) as pool:
            for result in pool.imap(partial_, xas_obj.spectra):
                results.append(result)
                self._advancement.increaseAdvancement()

        # then update local spectrum; (None, None) marks a skipped spectrum
        for spectrum, (_, res_spectrum) in zip(xas_obj.spectra, results):
            if res_spectrum is not None:
                # NOTE(review): assumes Spectrum.update accepts a Spectrum
                # instance - confirm against est.core.types.
                spectrum.update(res_spectrum)

    def definition(self):
        """Short human-readable description of the process."""
        return "mback norm calculation"

    def program_version(self):
        """Version of the installed larch package."""
        import larch.version

        return larch.version.version_data()["larch"]

    @staticmethod
    def program_name():
        """Name under which this process is registered."""
        return "larch_mback_norm"

    __call__ = run