Skip to content
types.py 24.7 KiB
Newer Older
# coding: utf-8
# /*##########################################################################
#
# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# ###########################################################################*/

# Module metadata
__authors__ = ["H. Payno"]
__license__ = "MIT"
__date__ = "06/11/2019"


from silx.io.dictdump import dicttoh5, h5todict
from silx.io.url import DataUrl
import numpy
import logging
# NOTE(review): h5py is used by XASObject.clean_process_flow and
# XASObject.copy_process_flow_to but is not imported here -- confirm.
# Module-level logger used for the warnings emitted by the classes below.
_logger = logging.getLogger(__name__)
class XASObject(object):
    """Base class of XAS

    :param spectra: absorbed beam as a list of :class:`.Spectrum` or a
                    numpy.ndarray of shape (n_points, dim1, dim2)
    :type: Union[numpy.ndarray, list]
    :param energy: beam energy
    :type: numpy.ndarray of one dimension
    :param dict configuration: configuration of the different process
    :param int dim1: first dimension of the spectra. Required when `spectra`
                     is a list/tuple, read from `spectra.shape[1]` otherwise.
    :param int dim2: second dimension of the spectra. Required when `spectra`
                     is a list/tuple, read from `spectra.shape[2]` otherwise.
    :param str name: name of the object. Will be used for the hdf5 entry

    To keep a trace of the processes applied to the object into an hdf5 file,
    link a file with :meth:`link_to_h5`.
    """

    def __init__(self, spectra=None, energy=None, configuration=None, dim1=None,
                 dim2=None, name='scan1'):
        self.__channels = None
        self.__spectra = []
        self.__energy = None
        self.__dim1 = 0
        self.__dim2 = 0
        self.__processing_index = 0
        self.__h5_file = None
        self.__entry_name = name
        self.spectra = (energy, spectra, dim1, dim2)
        self.configuration = configuration

    @property
    def entry(self):
        """Name of the hdf5 entry used for this object."""
        return self.__entry_name

    @property
    def spectra(self):
        """List of :class:`Spectrum`, stored row-major over (dim1, dim2)."""
        return self.__spectra

    @spectra.setter
    def spectra(self, energy_spectra):
        """Set the spectra from an ``(energy, spectra, dim1, dim2)`` tuple.

        ``spectra`` can be None (empty object), a 3D numpy.ndarray of shape
        (n_points, dim1, dim2) or a list/tuple of :class:`Spectrum`.

        :raises ValueError: if ``spectra`` is a list/tuple and ``dim1`` or
                            ``dim2`` is None.
        """
        energy, spectra, dim1, dim2 = energy_spectra
        if spectra is None:
            self.__spectra = []
            self.__energy = energy
        else:
            assert energy is not None
            # Rebind rather than clear in place. copy() makes shallow copies
            # that share this list, and the caller may pass the current list
            # itself. Clearing it would wipe those spectra.
            self.__spectra = []
            assert isinstance(spectra, (list, tuple, numpy.ndarray))
            if isinstance(spectra, numpy.ndarray):
                assert spectra.ndim == 3
                self.__dim1 = spectra.shape[1]
                self.__dim2 = spectra.shape[2]
                # axis 0 holds the mu values of a single spectrum
                for y_i_spectrum in range(spectra.shape[1]):
                    for x_i_spectrum in range(spectra.shape[2]):
                        self.addSpectrum(
                            Spectrum(energy=energy,
                                     mu=spectra[:, y_i_spectrum, x_i_spectrum]))
            else:
                if dim1 is None or dim2 is None:
                    raise ValueError(
                        'If you want to set spectra from a list/tuple '
                        'of Spectrum you should specify the spectra '
                        'dimensions')
                self.__dim1 = dim1
                self.__dim2 = dim2
                for spectrum in spectra:
                    assert isinstance(spectrum, Spectrum)
                    self.addSpectrum(spectrum)
        self.energy = energy

    def _setSpectra(self, spectra):
        """Replace the internal spectra container as-is (no check, dims
        are left unchanged)."""
        self.__spectra = spectra

    def getSpectrum(self, dim1_idx, dim2_idx):
        """Util function to access the spectrum at dim1_idx, dim2_idx"""
        assert dim1_idx < self.dim1
        assert dim2_idx < self.dim2
        # spectra are stored row-major: dim2 varies fastest
        global_idx = dim1_idx * self.dim2 + dim2_idx
        assert global_idx < len(self.spectra)
        assert global_idx >= 0
        return self.spectra[global_idx]

    def addSpectrum(self, spectrum):
        """Append a spectrum (dims are not updated)."""
        self.__spectra.append(spectrum)

    @property
    def dim1(self):
        """First dimension of the spectra map."""
        return self.__dim1

    def forceDim1(self, value):
        """Overwrite dim1 without any consistency check."""
        assert type(value) is int
        self.__dim1 = value

    def forceDim2(self, value):
        """Overwrite dim2 without any consistency check."""
        assert type(value) is int
        self.__dim2 = value

    @property
    def dim2(self):
        """Second dimension of the spectra map."""
        return self.__dim2

    @property
    def energy(self):
        """Beam energy shared by the spectra."""
        return self.__energy

    @energy.setter
    def energy(self, energy):
        self.__energy = energy
        # Only compare lengths when both sides are defined: len(None) would
        # raise TypeError.
        if energy is not None and len(self.__spectra) > 0:
            spectrum_energy = self.__spectra[0].energy
            if spectrum_energy is not None and len(spectrum_energy) != len(energy):
                _logger.warning('spectra and energy have incoherent dimension')

    @property
    def configuration(self):
        """Configuration of the different processes (always a dict)."""
        return self.__configuration

    @configuration.setter
    def configuration(self, configuration):
        assert configuration is None or isinstance(configuration, dict)
        self.__configuration = configuration or {}

    def to_dict(self, with_process_details=True):
        """convert the XAS object to a dict

        By default made to simply import raw data.

        :param with_process_details: used to embed a list of spectrum with
                                     intermediary result instead of only raw mu.
                                     This is needed especially for the
                                     pushworkflow actors to keep a trace of the
                                     processes.
        :type: bool
        :return: dict with 'configuration', 'spectra', 'energy', 'dim1' and
                 'dim2' keys, plus 'linked_h5_file' and
                 'current_processing_index' when with_process_details is True
        :rtype: dict
        """
        # Only build the representation that is actually returned.
        if with_process_details is True:
            spectra = [spectrum.to_dict() for spectrum in self.spectra]
        else:
            spectra = self.absorbed_beam()
        res = {
            'configuration': self.configuration,
            'spectra': spectra,
            'energy': self.energy,
            'dim1': self.dim1,
            'dim2': self.dim2,
        }
        if with_process_details is True:
            res['linked_h5_file'] = self.linked_h5_file
            res['current_processing_index'] = self.__processing_index

        return res

    def _spectra_to_dict(self):
        """Return {'<index>_spectrum': spectrum.to_dict()} for all spectra."""
        spectra_dict = {}
        for i_spectrum, spectrum in enumerate(self.spectra):
            assert isinstance(spectrum, Spectrum)
            spectra_dict[str(i_spectrum) + '_spectrum'] = spectrum.to_dict()
        return spectra_dict

    def absorbed_beam(self):
        """Return the mu of all spectra as an array (n_points, dim1, dim2),
        or None if there is no spectrum."""
        return XASObject._spectra_volume(spectra=self.spectra,
                                         key='Mu',
                                         dim_1=self.dim1,
                                         dim_2=self.dim2)

    @staticmethod
    def _spectra_volume(spectra, key, dim_1, dim_2):
        """Convert a list of spectra (mu) to a numpy array.

        :param spectra: list of spectra, row-major over (dim_1, dim_2)
        :param str key: '/'-separated path of the value to extract from each
                        spectrum (e.g. 'Mu')
        :param int dim_1: first dimension
        :param int dim_2: second dimension
        :return: array of shape (len(spectra[0].energy), dim_1, dim_2), or
                 None when `spectra` is empty

        ..note: only convert raw data for now"""
        if len(spectra) == 0:
            return None
        assert len(spectra) == dim_1 * dim_2
        subkeys = key.split('/')
        array = numpy.zeros((len(spectra[0].energy), dim_1 * dim_2))
        for i_spectrum, spectrum in enumerate(spectra):
            value = spectrum[subkeys[0]]
            for subkey in subkeys[1:]:
                value = value[subkey]
            array[:, i_spectrum] = value

        return array.reshape((len(spectra[0].energy), dim_1, dim_2))

    def load_frm_dict(self, ddict):
        """load XAS values from a dict

        The dict can be on the scheme of the to_dict function, containing
        the spectra and the configuration. Otherwise we consider it is simply
        the spectra.

        :param dict ddict: values to load
        :return: self
        """
        contains_config_spectrum = 'configuration' in ddict or 'spectra' in ddict
        if 'configuration' in ddict:
            self.configuration = ddict['configuration']
        if 'spectra' in ddict:
            spectra = ddict['spectra']
            if not isinstance(spectra, numpy.ndarray):
                new_spectra = []
                for spectrum in spectra:
                    assert isinstance(spectrum, dict)
                    new_spectra.append(Spectrum.from_dict(spectrum))
                spectra = new_spectra
        else:
            spectra = None
        energy = ddict['energy'] if 'energy' in ddict else None
        dim1 = ddict['dim1'] if 'dim1' in ddict else None
        dim2 = ddict['dim2'] if 'dim2' in ddict else None
        if 'linked_h5_file' in ddict:
            assert 'current_processing_index' in ddict
            self.link_to_h5(ddict['linked_h5_file'])
            self.__processing_index = ddict['current_processing_index']

        self.spectra = (energy, spectra, dim1, dim2)

        if not contains_config_spectrum:
            # NOTE(review): this only sets an otherwise unused 'spectrum'
            # attribute; the intent (loading ddict as spectra?) is unclear.
            self.spectrum = ddict
        return self

    @staticmethod
    def from_dict(ddict):
        """Create an XASObject from a dict (see :meth:`load_frm_dict`)."""
        return XASObject().load_frm_dict(ddict=ddict)

    @staticmethod
    def from_file(h5_file, entry='scan1', spectra_path='data/absorbed_beam',
                  energy_path='data/energy', configuration_path='configuration'):
        """Load mu, energy and (optionally) configuration from an hdf5 file.

        Paths are relative to `entry`. Set `configuration_path` to None to
        skip the configuration. Reading is delegated to
        `xas.io.read_pymca_xas` and its result is returned.
        """
        # load only mu and energy from the file
        import xas.io
        spectra_url = DataUrl(file_path=h5_file,
                              data_path='/'.join((entry, spectra_path)),
                              scheme='silx')
        energy_url = DataUrl(file_path=h5_file,
                             data_path='/'.join((entry, energy_path)),
                             scheme='silx')
        if configuration_path is None:
            config_url = None
        else:
            config_url = DataUrl(file_path=h5_file,
                                 data_path='/'.join((entry, configuration_path)),
                                 scheme='silx')
        return xas.io.read_pymca_xas(spectra_url=spectra_url,
                                     channel_url=energy_url,
                                     config_url=config_url)

    def dump(self, h5_file):
        """dump the XAS object to a file_path within the Nexus format
        (raw data only, without process details)"""
        dicttoh5(treedict=self.to_dict(with_process_details=False),
                 h5file=h5_file)

    def copy(self):
        """Return a shallow copy: the spectra list and the configuration
        dict are shared with the original object."""
        import copy
        return copy.copy(self)

    def __eq__(self, other):
        """Equal when energy, dimensions, configuration and the mu of every
        spectrum are equal. Defining __eq__ makes instances unhashable."""
        return (isinstance(other, XASObject) and
                numpy.array_equal(self.energy, other.energy) and
                self.dim1 == other.dim1 and
                self.dim2 == other.dim2 and
                self.configuration == other.configuration and
                self.spectra_equal(self.spectra, other.spectra))

    @staticmethod
    def spectra_equal(spectra1, spectra2):
        """Return True if both lists have the same length and the same mu
        for each spectrum."""
        if len(spectra1) != len(spectra2):
            return False
        for spectrum1, spectrum2 in zip(spectra1, spectra2):
            if not numpy.array_equal(spectrum1.mu, spectrum2.mu):
                return False
        return True

    @property
    def n_spectrum(self):
        """return the number of spectra"""
        if self.__spectra is None:
            return 0
        return len(self.__spectra)

    def spectra_keys(self):
        """keys contained by the spectrum object (energy, mu, normalizedmu...)

        :return: keys of the first spectrum, or None if there is no spectrum
        """
        if len(self.spectra) > 0:
            assert isinstance(self.spectra[0], Spectrum)
            return self.spectra[0].keys()

    @property
    def linked_h5_file(self):
        """Path of the .h5 file linked with :meth:`link_to_h5` (or None)."""
        return self.__h5_file

    def link_to_h5(self, h5_file):
        """
        Associate a .h5 file to the XASObject. This can be used for storing
        process flow.

        :param h5_file: path to the hdf5 file
        """
        self.__h5_file = h5_file

    def has_linked_file(self):
        """Return True if a .h5 file is linked to this object."""
        return self.__h5_file is not None

    def get_next_processing_index(self):
        """Increment and return the processing index."""
        self.__processing_index += 1
        return self.__processing_index

    def register_processing(self, process, data):
        """
        Register one process for the current xas object. This requires a
        h5file to be linked to this object.

        :param :class:`.Process` process: the process applied
        :param data: result of the processing. If there is more than one
                     result then a dictionary with the key under which result
                     should be saved and a numpy.ndarray
        :type: Union[numpy.ndarray, dict]
        """
        import xas.io
        xas.io.write_xas_proc(self.linked_h5_file, entry=self.__entry_name,
                              processing_order=self.get_next_processing_index(),
                              process=process, data=data)

    def get_process_flow(self):
        """
        Read the recognized processes stored in the linked .h5 file.

        :return: the dict of process information, indexed by processing
                 order. Empty if no file is linked.
        :rtype: dict
        """
        import xas.io

        if not self.linked_h5_file:
            _logger.warning('process flow is stored in the linked .h5 file. If '
                            'no link is defined then this information is not '
                            'stored')
            return {}

        know_process = ('pymca_normalization', 'pymca_exafs', 'pymca_ft',
                        'pymca_k_weight')
        all_process = xas.io.get_xasproc(self.linked_h5_file,
                                         entry=self.__entry_name)
        recognized_process = [
            process_ for process_ in all_process
            if 'program' in process_.keys() and process_['program'] in know_process
        ]

        res = {}
        for process_ in recognized_process:
            if 'processing_order' not in process_:
                _logger.warning('one processing has not processing order: %s',
                                process_['program'])
            else:
                res[int(process_['processing_order'])] = process_
        return res

    def clean_process_flow(self):
        """
        Remove existing process flow from the linked .h5 file
        """
        if not self.linked_h5_file:
            _logger.warning('process flow is stored in the linked .h5 file. If '
                            'no link is defined then this information is not '
                            'stored')
            return
        process_flow = self.get_process_flow()
        # NOTE(review): h5py must be imported at module level for this to run.
        # Explicit write mode: h5py >= 3 opens files read-only by default.
        with h5py.File(self.linked_h5_file, mode='a') as h5f:
            for process_ in process_flow.values():
                del h5f[process_['_h5py_path']]

    def copy_process_flow_to(self, h5_file_target):
        """
        copy all the recognized process from self.__h5_file to h5_file_target

        :param str h5_file_target: path to the targeted file. Should be an
                                   existing hdf5 file.
        """
        import os
        assert os.path.exists(h5_file_target)
        # NOTE(review): h5py must be imported at module level for this to run.
        assert h5py.is_hdf5(h5_file_target)

        flow = self.get_process_flow()
        entry = self.entry
        entry_prefix = '/' + entry + '/'
        # source is only read; target must be writable (h5py >= 3 default
        # mode is read-only)
        with h5py.File(self.__h5_file, mode='r') as source_hdf:
            with h5py.File(h5_file_target, mode='a') as target_hdf:
                target_entry = target_hdf.require_group(entry)
                for process in flow.values():
                    process_path = process['_h5py_path']
                    # destination is relative to the target entry group
                    dst_path = process_path.replace(entry_prefix, '', 1)
                    target_entry.copy(source=source_hdf[process_path],
                                      dest=dst_path)
# TODO: add the spectra class. Would speed up and simplify stuff probably
class Spectra:
    """Placeholder for a future container of several spectra (not
    implemented yet)."""

class Spectrum(object):
    """
    Set of curves (one dimensional numpy.ndarray) to be passed to the
    different XAS treatments.

    Can be accessed as a dictionary for non standard parameters: keys that
    are not mapped to a property are stored as extra parameters.

    :param numpy.ndarray (1D) energy: beam energy
    :param numpy.ndarray (1D) mu: beam absorption
    """
    _MU_KEY = 'Mu'

    _ENERGY_KEY = 'Energy'

    _NORMALIZED_MU_KEY = 'NormalizedMu'

    _NORMALIZED_ENERGY_KEY = 'NormalizedEnergy'

    _NORMALIZED_SIGNAL_KEY = 'NormalizedSignal'

    _FT_KEY = 'FT'

    def __init__(self, energy=None, mu=None):
        # properties (the setters check the types)
        self.energy = energy
        self.mu = mu
        self.__normalized_mu = None
        self.__normalized_energy = None
        self.__normalized_signal = None
        self.__other_parameters = {}
        self.ft = {}

        # dictionary-style keys mapped to the matching property
        self.__key_mapper = {
            self._MU_KEY: self.__class__.mu,
            self._ENERGY_KEY: self.__class__.energy,
            self._NORMALIZED_MU_KEY: self.__class__.normalized_mu,
            self._NORMALIZED_ENERGY_KEY: self.__class__.normalized_energy,
            self._NORMALIZED_SIGNAL_KEY: self.__class__.normalized_signal,
            self._FT_KEY: self.__class__.ft
        }

    @property
    def energy(self):
        return self.__energy

    @energy.setter
    def energy(self, energy):
        assert isinstance(energy, numpy.ndarray) or energy is None
        self.__energy = energy

    @property
    def mu(self):
        return self.__mu

    @mu.setter
    def mu(self, mu):
        assert isinstance(mu, numpy.ndarray) or mu is None
        self.__mu = mu

    @property
    def normalized_mu(self):
        return self.__normalized_mu

    @normalized_mu.setter
    def normalized_mu(self, mu):
        assert isinstance(mu, numpy.ndarray) or mu is None
        self.__normalized_mu = mu

    @property
    def normalized_energy(self):
        return self.__normalized_energy

    @normalized_energy.setter
    def normalized_energy(self, energy):
        assert isinstance(energy, numpy.ndarray) or energy is None
        self.__normalized_energy = energy

    @property
    def normalized_signal(self):
        return self.__normalized_signal

    @normalized_signal.setter
    def normalized_signal(self, signal):
        assert isinstance(signal, numpy.ndarray) or signal is None
        self.__normalized_signal = signal

    @property
    def ft(self):
        """FT values as an :class:`_FT` (a dict given to the setter is
        wrapped into an _FT)."""
        return self.__ft

    @ft.setter
    def ft(self, ft):
        if isinstance(ft, _FT):
            self.__ft = ft
        else:
            self.__ft = _FT(ddict=ft)

    @property
    def shape(self):
        """(len(energy), len(mu)); 0 is used for an undefined array."""
        _energy_len = 0
        if self.__energy is not None:
            _energy_len = len(self.__energy)
        _mu_len = 0
        if self.__mu is not None:
            _mu_len = len(self.__mu)

        return (_energy_len, _mu_len)

    def extra_keys(self):
        """Keys of the parameters not mapped to a property."""
        return self.__other_parameters.keys()

    def __getitem__(self, key):
        """Need for pymca compatibility"""
        if key in self.__key_mapper:
            return self.__key_mapper[key].fget(self)
        else:
            return self.__other_parameters[key]

    def __setitem__(self, key, value):
        """Need for pymca compatibility"""
        if key in self.__key_mapper:
            self.__key_mapper[key].fset(self, value)
        else:
            self.__other_parameters[key] = value

    def __contains__(self, item):
        return item in self.__key_mapper or item in self.__other_parameters

    def load_frm_dict(self, ddict):
        """Set every ``key: value`` pair of ``ddict`` on this spectrum.

        :param dict ddict: values to load (for example from :meth:`to_dict`)
        :return: self
        :rtype: Spectrum
        """
        assert isinstance(ddict, dict)
        for key, value in ddict.items():
            self[key] = value
        return self

    @staticmethod
    def from_dict(ddict):
        """Create a Spectrum from a dict (see :meth:`load_frm_dict`)."""
        spectrum = Spectrum()
        return spectrum.load_frm_dict(ddict=ddict)

    def to_dict(self):
        """Return all the values (mapped and extra) as a dict; FT values are
        converted with :meth:`_FT.to_dict`."""
        res = {
            self._MU_KEY: self.mu,
            self._ENERGY_KEY: self.energy,
            self._FT_KEY: self.ft.to_dict(),
            self._NORMALIZED_MU_KEY: self.normalized_mu,
            self._NORMALIZED_ENERGY_KEY: self.normalized_energy,
            self._NORMALIZED_SIGNAL_KEY: self.normalized_signal,
        }
        res.update(self.__other_parameters)
        return res

    def __str__(self):
        def add_info(str_, attr):
            assert hasattr(self, attr)
            sub_str = '- ' + attr + ': ' + str(getattr(self, attr)) + '\n'
            return (str_ + sub_str)
        main_info = ""
        for info in ('energy', 'mu', 'normalized_mu', 'normalized_signal', 'normalized_energy'):
            main_info = add_info(str_=main_info, attr=info)

        def add_third_info(str_, key):
            sub_str = ('- ' + key + ': ' + str(self[key])) + '\n'
            return str_ + sub_str
        for key in self.__other_parameters:
            main_info = add_third_info(str_=main_info, key=key)
        return main_info

    def update(self, obj):
        """
        Update the contained values from the given obj.

        :param obj: values to copy; a Spectrum is converted with
                    :meth:`to_dict` first, so all its mapped values
                    (including None ones) are copied.
        :type obj: Union[Spectrum, dict]
        """
        if isinstance(obj, Spectrum):
            _obj = obj.to_dict()
        else:
            _obj = obj
        assert isinstance(_obj, dict)
        for key, value in _obj.items():
            self[key] = value

    def get_missing_keys(self, keys):
        """Return missing keys on the spectrum, or None if none is missing"""
        missing = [key for key in keys if key not in self]
        return missing or None

    def keys(self):
        """Extra parameter keys followed by the mapped keys."""
        keys = list(self.__other_parameters.keys())
        keys += list(self.__key_mapper.keys())
        return keys


class _FT(object):
    """Container of the 'FT' values of a :class:`Spectrum`.

    Known keys ('FTRadius', 'FTIntensity', 'FTImaginary') are mapped to the
    ``radius``, ``intensity`` and ``imaginery`` properties; any other key is
    kept as an extra parameter.

    :param ddict: initial values as key -> value. None (default) gives an
                  empty object.
    :type ddict: Union[dict, None]
    """

    _RADIUS_KEY = 'FTRadius'

    _INTENSITY_KEY = 'FTIntensity'

    _IMAGINERY_KEY = 'FTImaginary'

    def __init__(self, ddict=None):
        self.__radius = None
        self.__intensity = None
        self.__imaginery = None
        self.__other_parameters = {}

        # dictionary-style keys mapped to the matching property
        self.__key_mapper = {
            self._RADIUS_KEY: self.__class__.radius,
            self._INTENSITY_KEY: self.__class__.intensity,
            self._IMAGINERY_KEY: self.__class__.imaginery,
        }

        if ddict is not None:
            for key, values in ddict.items():
                self[key] = values

    @property
    def radius(self):
        return self.__radius

    @radius.setter
    def radius(self, radius):
        self.__radius = radius

    @property
    def intensity(self):
        return self.__intensity

    @intensity.setter
    def intensity(self, intensity):
        self.__intensity = intensity

    @property
    def imaginery(self):
        return self.__imaginery

    @imaginery.setter
    def imaginery(self, imaginery):
        self.__imaginery = imaginery

    def __getitem__(self, key):
        """Need for pymca compatibility"""
        if key in self.__key_mapper:
            return self.__key_mapper[key].fget(self)
        else:
            return self.__other_parameters[key]

    def __setitem__(self, key, value):
        """Need for pymca compatibility"""
        if key in self.__key_mapper:
            self.__key_mapper[key].fset(self, value)
        else:
            self.__other_parameters[key] = value

    def __contains__(self, item):
        return item in self.__key_mapper or item in self.__other_parameters

    def to_dict(self):
        """Return mapped and extra values as a dict."""
        res = {
            self._RADIUS_KEY: self.radius,
            self._INTENSITY_KEY: self.intensity,
            self._IMAGINERY_KEY: self.imaginery,
        }
        res.update(self.__other_parameters)
        return res

    def get_missing_keys(self, keys):
        """Return missing keys on this FT object, or None if none is
        missing"""
        missing = [key for key in keys if key not in self]
        return missing or None


class Sample(object):
    """Description of the sample. Needed for writing a valid nx file.

    :param str name: name of the sample
    :param description: optional description of the sample
    """
    def __init__(self, name='undefined sample', description=None):
        self.description = description
        self.name = name