Commit 15d74183 authored by Damien Naudet's avatar Damien Naudet

Refactored fit code.

parent 7364b7b5
......@@ -34,10 +34,70 @@ import numpy as np
from scipy.optimize import leastsq
from .fitresults import FitStatus
# Some constants
_const_inv_2_pi_ = np.sqrt(2 * np.pi)
class Fitter(object):
def __init__(self, qx, qy, qz,
shared_results):
super(Fitter, self).__init__()
self._shared_results = shared_results
self._qx = qx
self._qy = qy
self._qz = qz
def fit(self, i_cube, qx_profile, qy_profile, qz_profile):
raise NotImplementedError('Not implemented.')
class GaussianFitter(Fitter):
def __init__(self, *args, **kwargs):
super(GaussianFitter, self).__init__(*args, **kwargs)
self._z_0 = [1.0, self._qz.mean(), 1.0]
self._y_0 = [1.0, self._qy.mean(), 1.0]
self._x_0 = [1.0, self._qx.mean(), 1.0]
def fit(self, i_cube, qx_profile, qy_profile, qz_profile):
z_fit, success_z = gaussian_fit(self._qz, qz_profile, self._z_0)
y_fit, success_y = gaussian_fit(self._qy, qy_profile, self._y_0)
x_fit, success_x = gaussian_fit(self._qx, qx_profile, self._x_0)
self._shared_results.set_qz_results(i_cube, z_fit, success_z)
self._shared_results.set_qy_results(i_cube, y_fit, success_y)
self._shared_results.set_qx_results(i_cube, x_fit, success_x)
class CentroidFitter(Fitter):
def fit(self, i_cube, qx_profile, qy_profile, qz_profile):
com = self._qz.dot(qz_profile) / qz_profile.sum()
idx = np.abs(self._qz - com).argmin()
i_max = qz_profile.max()
self._shared_results.set_qz_results(i_cube,
[qz_profile[idx], com, i_max],
FitStatus.OK)
com = self._qy.dot(qy_profile) / qy_profile.sum()
idx = np.abs(self._qy - com).argmin()
i_max = qy_profile.max()
self._shared_results.set_qy_results(i_cube,
[qy_profile[idx], com, i_max],
FitStatus.OK)
com = self._qx.dot(qx_profile) / qx_profile.sum()
idx = np.abs(self._qx - com).argmin()
i_max = qx_profile.max()
self._shared_results.set_qx_results(i_cube,
[qx_profile[idx], com, i_max],
FitStatus.OK)
# 1d Gaussian func
# TODO : optimize
def gaussian(x, a, c, s):
......@@ -46,7 +106,7 @@ def gaussian(x, a, c, s):
:param x: values for which the gaussian must be computed
:param a: area under curve ( amplitude * s * sqrt(2 * pi) )
:param c: center
:param stdev: standard deviation
:param s: sigma
:return: (a / (sqrt(2 * pi) * s)) * exp(- 0.5 * ((x - c) / s)^2)
"""
return (a * (1. / (_const_inv_2_pi_ * s)) *
......@@ -57,7 +117,6 @@ def gaussian(x, a, c, s):
# TODO : optimize
def gaussian_fit_err(p, x, y):
"""
:param p:
:param x:
:param y:
......@@ -83,26 +142,9 @@ def gaussian_fit(x, y, p):
full_output=True)
if result[4] not in [1, 2, 3, 4]:
raise ValueError('Failed to fit : {0}.'.format(result[3]))
return result[0]
return [np.nan, np.nan, np.nan], FitStatus.FAILED
def centroid(x, y, p):
"""
Computes the center of mass of the provided data.
Returns the value closest to the center of mass, and the
the center of mass
:param x:
:param y:
:param p:
:return: list
"""
# TODO : throw exception if fit failed
com = x.dot(y) / y.sum()
idx = np.abs(x - com).argmin()
i_max = y.max()
return [y[idx], com, i_max]
return result[0], FitStatus.OK
def _gauss_first_guess(x, y):
......
......@@ -67,12 +67,14 @@ class FitResult(object):
self._processes = OrderedDict()
n_pts = len(sample_x)
self._n_pts = n_pts = len(sample_x)
self._status = OrderedDict([('qx_status', np.zeros(n_pts)),
('qy_status', np.zeros(n_pts)),
('qz_status', np.zeros(n_pts))])
self._infos = OrderedDict()
def processes(self):
"""
Returns the process names
......@@ -172,6 +174,64 @@ class FitResult(object):
param_data = self._get_param(process, param)
param_data[self._AXIS_NAMES[axis]] = result
def add_qx_info(self, name, data):
"""
Add other misc. plottable info (e.g : chi2, ...) associated with the
fit along the qx axis.
data must be an array with the same number of elements
:param name:
:param axis:
:param data:
:return:
"""
self._add_info(name, self.QX_AXIS, data)
def add_qy_info(self, name, data):
"""
Add other misc. plottable info (e.g : chi2, ...) associated with the
fit along the qy axis.
data must be an array with the same number of elements
:param name:
:param axis:
:param data:
:return:
"""
self._add_info(name, self.QY_AXIS, data)
def add_qz_info(self, name, data):
"""
Add other misc. plottable info (e.g : chi2, ...) associated with the
fit along the qz axis.
data must be an array with the same number of elements
:param name:
:param axis:
:param data:
:return:
"""
self._add_info(name, self.QZ_AXIS, data)
def _add_info(self, name, axis, data):
"""
Add other misc. plottable info (e.g : chi2, ...) associated with the
fit along the given axis.
data must be an array with the same number of elements
:param name:
:param axis:
:param data:
:return:
"""
assert axis in self._AXIS
infos = self._infos
if name not in infos:
info = OrderedDict([('qx', None), ('qy', None), ('qz', None)])
infos[name] = info
else:
info = infos[name]
info[self._AXIS_NAMES[axis]] = data
def set_qx_status(self, status):
self._set_axis_status(self.QX_AXIS, status)
......
......@@ -40,7 +40,7 @@ import numpy as np
# from silx.math import curve_fit
from ..io import QSpaceH5
from .fit_funcs import gaussian_fit, centroid
from .fit_funcs import GaussianFitter, CentroidFitter
from .sharedresults import FitTypes, GaussianResults, CentroidResults
from .fitresults import FitStatus
......@@ -78,6 +78,7 @@ class PeakFitter(Thread):
self.__results = None
self.__thread = None
self.__progress = 0
self.__callback = None
self.__status = self.READY
......@@ -183,10 +184,10 @@ class PeakFitter(Thread):
# # success[:] = True
if fit_type == FitTypes.GAUSSIAN:
fit_fn = gaussian_fit
fit_class = GaussianFitter
shared_results = GaussianResults(n_points=n_indices)
if fit_type == FitTypes.CENTROID:
fit_fn = centroid
fit_class = CentroidFitter
shared_results = CentroidResults(n_points=n_indices)
# with h5py.File(qspace_f, 'r') as qspace_h5:
......@@ -229,18 +230,11 @@ class PeakFitter(Thread):
initializer=_init_thread,
initargs=(shared_results,
shared_progress,
fit_fn,
fit_class,
(n_indices, 9),
idx_queue,
qspace_f,
read_lock))
# initargs=(shared_res,
# shared_success,
# fit_fn,
# (n_indices, 9),
# idx_queue,
# qspace_f,
# read_lock))
if disp_times:
class myTimes(object):
......@@ -321,14 +315,14 @@ class PeakFitter(Thread):
def _init_thread(shared_res_,
shared_prog_,
fit_fn_,
fit_class_,
result_shape_,
idx_queue_,
qspace_f_,
read_lock_):
global shared_res, \
shared_progress, \
fit_fn, \
fit_class, \
result_shape, \
idx_queue, \
qspace_f, \
......@@ -336,7 +330,7 @@ def _init_thread(shared_res_,
shared_res = shared_res_
shared_progress = shared_prog_
fit_fn = fit_fn_
fit_class = fit_class_
result_shape = result_shape_
idx_queue = idx_queue_
qspace_f = qspace_f_
......@@ -350,13 +344,8 @@ def _fit_process(th_idx, roiIndices=None):
t_fit = 0.
t_mask = 0.
# results = np.frombuffer(shared_res)
# results.shape = result_shape
# success = np.frombuffer(shared_success, dtype=bool)
l_shared_res = shared_res.local_copy()
progress = np.frombuffer(shared_progress, dtype='int32')
# results = l_shared_res._shared_array
# success = l_shared_res._shared_status
qspace_h5 = QSpaceH5.QSpaceH5(qspace_f)
......@@ -368,15 +357,7 @@ def _fit_process(th_idx, roiIndices=None):
# TODO : timeout to check if it has been canceled
# read_lock.acquire()
# with h5py.File(qspace_f, 'r') as qspace_h5:
# q_x = qspace_h5['bins_edges/x'][:]
# q_y = qspace_h5['bins_edges/y'][:]
# q_z = qspace_h5['bins_edges/z'][:]
# q_shape = qspace_h5['data/qspace'].shape
# q_dtype = qspace_h5['data/qspace'].dtype
# mask = np.where(qspace_h5['histo'][:] > 0)
# weights = qspace_h5['histo'][:][mask]
with qspace_h5 as qspace_h5:
with qspace_h5:
q_x = qspace_h5.qx
q_y = qspace_h5.qy
q_z = qspace_h5.qz
......@@ -395,13 +376,10 @@ def _fit_process(th_idx, roiIndices=None):
weights = histo[mask]
# read_lock.release()
# print weights.max(), min(weights)
read_cube = np.ascontiguousarray(np.zeros(q_shape[1:]),
dtype=q_dtype)
x_0 = None
y_0 = None
z_0 = None
fitter = fit_class(q_x, q_y, q_z, l_shared_res)
while True:
# TODO : timeout
......@@ -417,10 +395,6 @@ def _fit_process(th_idx, roiIndices=None):
'Processing cube {0}/{1}.'.format(i_cube, result_shape[0]))
t0 = time.time()
# with h5py.File(qspace_f, 'r') as qspace_h5:
# qspace_h5['data/qspace'].read_direct(cube,
# source_sel=np.s_[i_cube],
# dest_sel=None)
with qspace_h5.qspace_dset_ctx() as dset:
dset.read_direct(read_cube,
source_sel=np.s_[i_cube],
......@@ -438,97 +412,21 @@ def _fit_process(th_idx, roiIndices=None):
t0 = time.time()
success_x = FitStatus.OK
success_y = FitStatus.OK
success_z = FitStatus.OK
z_sum = cube.sum(axis=0).sum(axis=0)
# if z_0 is None:
# z_0 = _gauss_first_guess(q_z, z_sum)
z_0 = [1.0, q_z.mean(), 1.0]
try:
fit_z = fit_fn(q_z, z_sum, z_0)
z_0 = fit_z
except Exception as ex:
# print('Z Failed', ex)
z_0 = None
fit_z = [np.nan, np.nan, np.nan]
success_z = FitStatus.FAILED
l_shared_res.set_qz_results(i_cube, fit_z, success_z)
z_sum = 0
cube_sum_z = cube.sum(axis=2)
y_sum = cube_sum_z.sum(axis=0)
# if y_0 is None:
# y_0 = _gauss_first_guess(q_y, y_sum)
y_0 = [1.0, q_y.mean(), 1.0]
try:
fit_y = fit_fn(q_y, y_sum, y_0)
y_0 = fit_y
except Exception as ex:
# print('Y Failed', ex, i_cube)
y_0 = None
fit_y = [np.nan, np.nan, np.nan]
success_y = FitStatus.FAILED
l_shared_res.set_qy_results(i_cube, fit_y, success_y)
y_sum = 0
x_sum = cube_sum_z.sum(axis=1)
# if x_0 is None:
# x_0 = _gauss_first_guess(q_x, x_sum)
x_0 = [1.0, q_x.mean(), 1.0]
try:
fit_x = fit_fn(q_x, x_sum, x_0)
x_0 = fit_x
except Exception as ex:
# print('X Failed', ex)
x_0 = None
fit_x = [np.nan, np.nan, np.nan]
success_x = FitStatus.FAILED
l_shared_res.set_qx_results(i_cube, fit_x, success_x)
x_sum = 0
fitter.fit(i_cube, x_sum, y_sum, z_sum)
t_fit += time.time() - t0
t0 = time.time()
# success[i_cube] = True
#
# if success_x:
# results[i_cube, 0:3] = fit_x
# else:
# results[i_cube, 0:3] = np.nan
# success[i_cube] = False
#
# if success_y:
# results[i_cube, 3:6] = fit_y
# else:
# results[i_cube, 3:6] = np.nan
# success[i_cube] = False
#
# if success_z:
# results[i_cube, 6:9] = fit_z
# else:
# results[i_cube, 6:9] = np.nan
# success[i_cube] = False
t_write = time.time() - t0
except Exception as ex:
print 'EX', ex
print('EX', ex)
times = (t_read, t_mask, t_fit, t_write)
if disp_times:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment