Commit 55e5aeca authored by Thomas Vincent's avatar Thomas Vincent
Browse files

rewrite FitResults

parent 3b1c1bac
#!/usr/bin/python
# coding: utf8
# /*##########################################################################
#
# Copyright (c) 2015-2016 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# ###########################################################################*/
from __future__ import absolute_import
__authors__ = ["D. Naudet"]
__date__ = "01/06/2016"
__license__ = "MIT"
from collections import OrderedDict
import numpy as np
from ...io.FitH5 import FitH5Writer, FitH5QAxis
class FitStatus(object):
"""
Enum for the fit status
Starting at 1 for compatibility reasons.
"""
UNKNOWN, OK, FAILED = range(0, 3)
class FitResult(object):
"""
Fit results
"""
_AXIS = QX_AXIS, QY_AXIS, QZ_AXIS = range(3)
_AXIS_NAMES = ('qx', 'qy', 'qz')
def __init__(self, entry,
q_x, q_y, q_z,
sample_x, sample_y,
background_mode):
super(FitResult, self).__init__()
self._entry = entry
self._sample_x = sample_x
self._sample_y = sample_y
self._q_x = q_x
self._q_y = q_y
self._q_z = q_z
self._background_mode = background_mode
self._processes = OrderedDict()
self._n_pts = len(sample_x)
self._status = OrderedDict([('qx_status', np.zeros(self._n_pts)),
('qy_status', np.zeros(self._n_pts)),
('qz_status', np.zeros(self._n_pts))])
self._infos = OrderedDict()
def processes(self):
"""
Returns the process names
:return:
"""
return self._processes.keys()
def params(self, process):
return self._get_process(process, create=False)['params'].keys()
def status(self, axis):
"""
Returns the status for the given axis.
:param axis:
:return:
"""
assert axis in self._AXIS
return self._status[self._AXIS_NAMES[axis]][:]
def qx_status(self):
"""
Returns qx fit status
:return:
"""
return self.status(self.QX_AXIS)
def qy_status(self):
"""
Returns qy fit status
:return:
"""
return self.status(self.QY_AXIS)
def qz_status(self):
"""
Returns qz fit status
:return:
"""
return self.status(self.QZ_AXIS)
def results(self, process, param, axis=None):
"""
Returns the fitted parameter results for a given process.
:param process: process name
:param param: param name
:param axis: if provided, returns only the result for the given axis
:return:
"""
param = self._get_param(process, param, create=False)
if axis is not None:
assert axis in self._AXIS
return param[self._AXIS_NAMES[axis]]
return param
def qx_results(self, process, param):
"""
Returns qx fit results for the given process
:param process:
:param param: param name
:return:
"""
return self.results(process, param, axis=self.QX_AXIS)
def qy_results(self, process, param):
"""
Returns qy fit results for the given process
:param process:
:param param: param name
:return:
"""
return self.results(process, param, axis=self.QY_AXIS)
def qz_results(self, process, param):
"""
Returns qz fit results for the given process
:param process:
:param param: param name
:return:
"""
return self.results(process, param, axis=self.QZ_AXIS)
def add_result(self, axis, process, param, result):
self._add_axis_result(process, self._AXIS_NAMES.index(axis), param, result)
def _add_axis_result(self, process, axis, param, result):
assert axis in self._AXIS
param_data = self._get_param(process, param)
param_data[self._AXIS_NAMES[axis]] = result
def add_qx_info(self, name, data):
"""
Add other misc. plottable info (e.g : chi2, ...) associated with the
fit along the qx axis.
data must be an array with the same number of elements
:param name:
:param data:
:return:
"""
self._add_info(name, self.QX_AXIS, data)
def add_qy_info(self, name, data):
"""
Add other misc. plottable info (e.g : chi2, ...) associated with the
fit along the qy axis.
data must be an array with the same number of elements
:param name:
:param data:
:return:
"""
self._add_info(name, self.QY_AXIS, data)
def add_qz_info(self, name, data):
"""
Add other misc. plottable info (e.g : chi2, ...) associated with the
fit along the qz axis.
data must be an array with the same number of elements
:param name:
:param data:
:return:
"""
self._add_info(name, self.QZ_AXIS, data)
def _add_info(self, name, axis, data):
"""
Add other misc. plottable info (e.g : chi2, ...) associated with the
fit along the given axis.
data must be an array with the same number of elements
:param name:
:param axis:
:param data:
:return:
"""
assert axis in self._AXIS
infos = self._infos
if name not in infos:
info = OrderedDict([('qx', None), ('qy', None), ('qz', None)])
infos[name] = info
else:
info = infos[name]
info[self._AXIS_NAMES[axis]] = data
def set_status(self, axis, status):
self._set_axis_status(self._AXIS_NAMES.index(axis), status)
def _set_axis_status(self, axis, status):
assert axis in self._AXIS
self._status[self._AXIS_NAMES[axis]] = status
def _get_process(self, process, create=True):
if process not in self._processes:
if not create:
raise KeyError('Unknown process {0}.'.format(process))
_process = OrderedDict([('params', OrderedDict())])
self._processes[process] = _process
else:
_process = self._processes[process]
return _process
def _get_param(self, process, param, create=True):
process = self._get_process(process, create=create)
params = process['params']
if param not in params:
if not create:
raise KeyError('Unknown param {0}.'.format(param))
_param = OrderedDict()
for axis in self._AXIS_NAMES:
_param[axis] = None
params[param] = _param
else:
_param = params[param]
return _param
entry = property(lambda self: self._entry)
sample_x = property(lambda self: self._sample_x)
sample_y = property(lambda self: self._sample_y)
q_x = property(lambda self: self._q_x)
q_y = property(lambda self: self._q_y)
q_z = property(lambda self: self._q_z)
def to_fit_h5(self, fit_h5, mode=None):
with FitH5Writer(fit_h5, mode=mode) as fitH5:
entry = self.entry
fitH5.create_entry(entry)
fitH5.set_scan_x(entry, self.sample_x)
fitH5.set_scan_y(entry, self.sample_y)
fitH5.set_qx(entry, self.q_x)
fitH5.set_qy(entry, self.q_y)
fitH5.set_qz(entry, self.q_z)
fitH5.set_background_mode(entry, self._background_mode)
processes = self.processes()
for process in processes:
fitH5.create_process(entry, process)
for param in self.params(process):
xresult = self.results(process, param,
self.QX_AXIS)
yresult = self.results(process, param,
self.QY_AXIS)
zresult = self.results(process, param,
self.QZ_AXIS)
fitH5.set_qx_result(entry,
process,
param,
xresult)
fitH5.set_qy_result(entry,
process,
param,
yresult)
fitH5.set_qz_result(entry,
process,
param,
zresult)
xstatus = self.qx_status()
ystatus = self.qy_status()
zstatus = self.qz_status()
fitH5.set_status(entry,
FitH5QAxis.qx_axis,
xstatus)
fitH5.set_status(entry,
FitH5QAxis.qy_axis,
ystatus)
fitH5.set_status(entry,
FitH5QAxis.qz_axis,
zstatus)
...@@ -34,16 +34,15 @@ import logging ...@@ -34,16 +34,15 @@ import logging
import functools import functools
import multiprocessing import multiprocessing
import numpy as np import numpy
from scipy.optimize import leastsq from scipy.optimize import leastsq
from silx.math.fit import snip1d from silx.math.fit import snip1d
from ... import config from ... import config
from ...io import QSpaceH5 from ...io import QSpaceH5
from ...io.FitH5 import BackgroundTypes from ...io.FitH5 import BackgroundTypes, FitH5Writer
from ...util import gaussian, project from ...util import gaussian, project
from .fitresults import FitResult
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
...@@ -61,18 +60,18 @@ def background_estimation(mode, data): ...@@ -61,18 +60,18 @@ def background_estimation(mode, data):
# Background subtraction # Background subtraction
if mode == BackgroundTypes.CONSTANT: if mode == BackgroundTypes.CONSTANT:
# Shift data so that smallest value is 0 # Shift data so that smallest value is 0
return np.ones_like(data) * np.nanmin(data) return numpy.ones_like(data) * numpy.nanmin(data)
elif mode == BackgroundTypes.LINEAR: elif mode == BackgroundTypes.LINEAR:
# Simple linear background # Simple linear background
return np.linspace(data[0], data[-1], num=len(data), endpoint=True) return numpy.linspace(data[0], data[-1], num=len(data), endpoint=True)
elif mode == BackgroundTypes.SNIP: elif mode == BackgroundTypes.SNIP:
# Using snip background # Using snip background
return snip1d(data, snip_width=len(data)) return snip1d(data, snip_width=len(data))
elif mode == BackgroundTypes.NONE: elif mode == BackgroundTypes.NONE:
return np.zeros_like(data) return numpy.zeros_like(data)
else: else:
raise ValueError("Unsupported background mode") raise ValueError("Unsupported background mode")
...@@ -84,6 +83,127 @@ class FitTypes(object): ...@@ -84,6 +83,127 @@ class FitTypes(object):
GAUSSIAN, CENTROID = ALLOWED GAUSSIAN, CENTROID = ALLOWED
class FitStatus(object):
"""
Enum for the fit status
Starting at 1 for compatibility reasons.
"""
UNKNOWN, OK, FAILED = range(0, 3)
class FitResult(object):
"""Object storing fit/com results
It also allows to save as hdf5.
:param numpy.ndarray sample_x: N X sample position of the results
:param numpy.ndarray sample_y: N Y sample position of the results
:param List[numpy.ndarray] q_dim_values:
Values along each axis of the QSpace
:param List[str] q_dim_names:
Name of axes for each dimension of the QSpace
:param FitTypes fit_mode: Kind of fit
:param BackgroundTypes background_mode: Kind of background subtraction
:param numpy.ndarray fit_results:
The fit/com results as a N (points) x 3 (axes) array of struct
containing the results.
Warning: This array is used as is and not copied.
"""
def __init__(self,
sample_x, sample_y,
q_dim_values,
q_dim_names,
fit_mode, background_mode,
fit_results):
super(FitResult, self).__init__()
self.sample_x = sample_x
"""X position on the sample of each fit result (numpy.ndarray)"""
self.sample_y = sample_y
"""Y position on the sample of each fit result (numpy.ndarray)"""
self.qspace_dimension_values = q_dim_values
"""QSpace axis values (List[numpy.ndarray])"""
self.qspace_dimension_names = q_dim_names
"""QSpace axis names (List[str])"""
self.fit_mode = fit_mode
"""Fit type (FitTypes)"""
self.background_mode = background_mode
"""Background type (BackgroundTypes)"""
# transpose from N (points) x 3 (axes) to 3 (axes) x N (points)
self._fit_results = numpy.transpose(fit_results)
@property
def available_results(self, dimension=None):
"""Returns the available result names
:param Union[int,None] dimension:
:rtype: List[str]
"""
if dimension is None:
dimension = 0
return self._fit_results[dimension].dtype.names
def get_results(self, dimension, parameter, copy=True):
"""Returns a given parameter of the result
:param int dimension: QSpace dimension from which to return result
:param str parameter: Name of the result to return
:param bool copy: True to return a copy, False to return internal data
:return: A 1D array
:rtype: numpy.ndarray
"""
return numpy.array(self._fit_results[dimension][parameter], copy=copy)
def to_fit_h5(self, fit_h5, mode=None):
"""Write fit results to an HDF5 file
:param str fit_h5: Filename where to save fit results
:param Union[None,str] mode: HDF5 file opening mode
"""
if self.fit_mode == FitTypes.GAUSSIAN:
fit_name = 'Gaussian'
result_name = 'gauss_0'
elif self.fit_mode == FitTypes.CENTROID:
fit_name = 'Centroid'
result_name = 'centroid'
else:
raise RuntimeError('Unknown Fit Type')
with FitH5Writer(fit_h5, mode=mode) as fitH5:
fitH5.create_entry(fit_name)
fitH5.set_scan_x(fit_name, self.sample_x)
fitH5.set_scan_y(fit_name, self.sample_y)
q_dim0, q_dim1, q_dim2 = self.qspace_dimension_values
fitH5.set_qx(fit_name, q_dim0)
fitH5.set_qy(fit_name, q_dim1)
fitH5.set_qz(fit_name, q_dim2)
fitH5.set_background_mode(fit_name, self.background_mode)
fitH5.create_process(fit_name, result_name)
for array, func, axis in zip(
self._fit_results,
(fitH5.set_qx_result, fitH5.set_qy_result, fitH5.set_qz_result),
(0, 1, 2)):
for name in self.available_results:
results = self.get_results(axis, name, copy=False)
if name == 'Status':
fitH5.set_status(fit_name, axis, results)
else:
func(fit_name, result_name, name, results)
class PeakFitter(object): class PeakFitter(object):
"""Class performing fit/com processing """Class performing fit/com processing
...@@ -120,7 +240,7 @@ class PeakFitter(object): ...@@ -120,7 +240,7 @@ class PeakFitter(object):
self.__n_proc = n_proc if n_proc else config.DEFAULT_PROCESS_NUMBER self.__n_proc = n_proc if n_proc else config.DEFAULT_PROCESS_NUMBER
if roi_indices is not None: if roi_indices is not None:
self.__roi_indices = np.array(roi_indices[:]) self.__roi_indices = numpy.array(roi_indices[:])
else: else:
self.__roi_indices = None self.__roi_indices = None
...@@ -142,9 +262,9 @@ class PeakFitter(object): ...@@ -142,9 +262,9 @@ class PeakFitter(object):
if indices is None: if indices is None:
n_points = qdata_shape[0] n_points = qdata_shape[0]
self.__indices = np.arange(n_points) self.__indices = numpy.arange(n_points)
else: else:
self.__indices = np.array(indices, copy=True) self.__indices = numpy.array(indices, copy=True)
def __set_status(self, status): def __set_status(self, status):
assert status in self.__STATUSES assert status in self.__STATUSES
...@@ -195,6 +315,7 @@ class PeakFitter(object): ...@@ -195,6 +315,7 @@ class PeakFitter(object):
x_pos = qspace_h5.sample_x[self.__indices] x_pos = qspace_h5.sample_x[self.__indices]
y_pos = qspace_h5.sample_y[self.__indices] y_pos = qspace_h5.sample_y[self.__indices]
q_dim0, q_dim1, q_dim2 = qspace_h5.qspace_dimension_values q_dim0, q_dim1, q_dim2 = qspace_h5.qspace_dimension_values
q_dim_names = qspace_h5.qspace_dimension_names
if self.__roi_indices is not None: if self.__roi_indices is not None:
q_dim0 = q_dim0[self.__roi_indices[0][0]:self.__roi_indices[0][1]] q_dim0 = q_dim0[self.__roi_indices[0][0]:self.__roi_indices[0][1]]
...@@ -202,41 +323,30 @@ class PeakFitter(object): ...@@ -202,41 +323,30 @@ class PeakFitter(object):
q_dim2 = q_dim2[self.__roi_indices[2][0]:self.__roi_indices[2][1]] q_dim2 = q_dim2[self.__roi_indices[2][0]:self.__roi_indices[2][1]]
if self.__fit_type == FitTypes.GAUSSIAN: if self.__fit_type == FitTypes.GAUSSIAN:
fit_name = 'Gaussian' result_dtype = [('Area', numpy.float64),
result_name = 'gauss_0' ('Center', numpy.float64),
result_dtype = [('Area', np.float64), ('Sigma', numpy.float64),
('Center', np.float64), ('Status', numpy.bool_)]
('Sigma', np.float64),
('Status', np.bool_)]
elif self.__fit_type == FitTypes.CENTROID: elif self.__fit_type == FitTypes.CENTROID:
fit_name = 'Centroid' result_dtype = [('COM', numpy.float64),
result_name = 'centroid' ('I_sum', numpy.float64),
result_dtype = [('COM', np.float64), ('I_max', numpy.float64),
('I_sum', np.float64), ('Pos_max', numpy.float64),
('I_max', np.float64), ('Status', numpy.bool_)]
('Pos_max', np.float64),
('Status', np.bool_)]
else: else:
raise RuntimeError('Unknown Fit Type') raise RuntimeError('Unknown Fit Type')
results = FitResult(entry=fit_name, self.__results = FitResult(
sample_x=x_pos, sample_x=x_pos,
sample_y=y_pos, sample_y=y_pos,
q_x=q_dim0, q_dim_values=(q_dim0, q_dim1, q_dim2),
q_y=q_dim1, q_dim_names=q_dim_names,
q_z=q_dim2, fit_mode=self.__fit_type,
background_mode=self.__background) background_mode=self.__background,
fit_results = np.array(fit_results, dtype=result_dtype) fit_results=numpy.array(fit_results, dtype=result_dtype))
# From points x axes to axes x points