Commit c6a7dcc7 authored by Damien Naudet's avatar Damien Naudet

Preparing the multiple gaussians fit.

parent edce4b3e
......@@ -33,7 +33,7 @@ from collections import OrderedDict
from silx.gui import qt as Qt
from ...io.QSpaceH5 import QSpaceH5
from ...io.FitH5 import FitH5Writer
from ...io.FitH5 import FitH5Writer, FitH5QAxis
from ..widgets.Containers import GroupBox
from ..widgets.Input import StyledLineEdit
from ..widgets.FileChooser import FileChooser
......@@ -181,34 +181,88 @@ class FitWidget(Qt.QDialog):
results = peak_fit(self.__qspaceFile,
fit_type=self.__fitType,
roiIndices=self.__roiIndices)
with FitH5Writer(self.__selectedFile, mode='w') as fitH5:
entry = results.fit_name
process = results.fit_name
entry = results.entry
fitH5.create_entry(entry)
fitH5.create_process(entry, process)
fitH5.set_scan_x(entry, results.sample_x)
fitH5.set_scan_y(entry, results.sample_y)
fitH5.set_status(entry, process, results.status)
fitH5.set_qx(entry, results.q_x)
fitH5.set_qy(entry, results.q_y)
fitH5.set_qz(entry, results.q_z)
for resName, data in results.q_x_results.items():
fitH5.set_qx_result(entry,
process,
resName,
data)
for resName, data in results.q_y_results.items():
fitH5.set_qy_result(entry,
process,
resName,
data)
for resName, data in results.q_z_results.items():
fitH5.set_qz_result(entry,
process,
resName,
data)
processes = results.processes()
for process in processes:
fitH5.create_process(entry, process)
for param in results.params(process):
xresult = results.results(process, param, results.QX_AXIS)
yresult = results.results(process, param, results.QY_AXIS)
zresult = results.results(process, param, results.QZ_AXIS)
xstatus = results.qx_status(process)
ystatus = results.qy_status(process)
zstatus = results.qz_status(process)
fitH5.set_qx_result(entry,
process,
param,
xresult)
fitH5.set_qy_result(entry,
process,
param,
yresult)
fitH5.set_qz_result(entry,
process,
param,
zresult)
fitH5.set_status(entry,
process,
FitH5QAxis.qx_axis,
xstatus)
fitH5.set_status(entry,
process,
FitH5QAxis.qy_axis,
ystatus)
fitH5.set_status(entry,
process,
FitH5QAxis.qz_axis,
zstatus)
# with FitH5Writer(self.__selectedFile, mode='w') as fitH5:
# entry = results.fit_name
# process = results.fit_name
# fitH5.create_entry(entry)
# fitH5.create_process(entry, process)
#
# fitH5.set_scan_x(entry, results.sample_x)
# fitH5.set_scan_y(entry, results.sample_y)
# fitH5.set_status(entry, process, results.status)
# fitH5.set_qx(entry, results.q_x)
# fitH5.set_qy(entry, results.q_y)
# fitH5.set_qz(entry, results.q_z)
#
# for resName, data in results.q_x_results.items():
# fitH5.set_qx_result(entry,
# process,
# resName,
# data)
# for resName, data in results.q_y_results.items():
# fitH5.set_qy_result(entry,
# process,
# resName,
# data)
# for resName, data in results.q_z_results.items():
# fitH5.set_qz_result(entry,
# process,
# resName,
# data)
self.__fitFile = self.__selectedFile
self._setStatus(FitWidget.StatusCompleted)
......
......@@ -153,7 +153,7 @@ class FitButton(EditorMixin, Qt.QWidget):
dialog.selectFile(os.path.join(workdir, itemBasename))
if dialog.exec_():
csvPath = dialog.selectedFiles()[0]
fitItem.fitH5.export_txt(csvPath)
fitItem.fitH5.export_csv(fitItem.fitH5.entries()[0], csvPath)
@H5NodeClassDef('FitGroup',
......
......@@ -36,7 +36,7 @@ from silx.gui import qt as Qt
from kmap.gui.project.XsocsH5Factory import h5NodeToProjectItem
from kmap.gui.widgets.Containers import GroupBox
from kmap.io.FitH5 import FitH5
from kmap.io.FitH5 import FitH5, FitH5QAxis
from ..widgets.XsocsPlot2D import XsocsPlot2D
from kmap.gui.model.TreeView import TreeView
......@@ -218,7 +218,7 @@ class FitView(Qt.QMainWindow):
if processes:
process = processes[0]
if process == 'LeastSq':
if process == 'gaussian':
_initLeastSq(self.__plots, fitH5.filename, entry, process)
elif process == 'Centroid':
_initCentroid(self.__plots, fitH5.filename, entry, process)
......@@ -269,7 +269,7 @@ class FitView(Qt.QMainWindow):
# TODO : refactor
process = self.__process
if process == 'LeastSq':
if process == 'gaussian':
_plotLeastSq(self.__fitPlots, xIdx,
fitH5,
entry, process,
......@@ -319,7 +319,7 @@ def _plotLeastSq(plots, index, fitH5,
xFitQY = fitH5.get_qy(entry)
xFitQZ = fitH5.get_qz(entry)
heights = fitH5.get_result(entry, process, 'height')
heights = fitH5.get_result(entry, process, 'intensity')
positions = fitH5.get_result(entry, process, 'position')
widths = fitH5.get_result(entry, process, 'width')
......@@ -408,17 +408,17 @@ def _initLeastSq(plots, fitH5Name, entry, process):
qApp.processEvents()
plots[0].plotFitResult(fitH5Name, entry, process,
'position', FitH5.qx_axis)
'position', FitH5QAxis.qx_axis)
qApp.processEvents()
plots[1].plotFitResult(fitH5Name, entry, process,
'position', FitH5.qy_axis)
'position', FitH5QAxis.qy_axis)
qApp.processEvents()
plots[2].plotFitResult(fitH5Name, entry, process,
'position', FitH5.qz_axis)
'position', FitH5QAxis.qz_axis)
def _initCentroid(plots, fitH5Name, entry, process):
......@@ -435,13 +435,13 @@ def _initCentroid(plots, fitH5Name, entry, process):
qApp = Qt.qApp
# plots[0].setVisible(True)
qApp.processEvents()
plots[0].plotFitResult(fitH5Name, entry, process, 'position', FitH5.qx_axis)
plots[0].plotFitResult(fitH5Name, entry, process, 'position', FitH5QAxis.qx_axis)
# plots[1].setVisible(True)
qApp.processEvents()
plots[1].plotFitResult(fitH5Name, entry, process, 'position', FitH5.qy_axis)
plots[1].plotFitResult(fitH5Name, entry, process, 'position', FitH5QAxis.qy_axis)
# plots[2].setVisible(True)
qApp.processEvents()
plots[2].plotFitResult(fitH5Name, entry, process, 'position', FitH5.qz_axis)
plots[2].plotFitResult(fitH5Name, entry, process, 'position', FitH5QAxis.qz_axis)
if __name__ == '__main__':
......
......@@ -31,7 +31,7 @@ __date__ = "15/09/2016"
from silx.gui import qt as Qt
from kmap.io.FitH5 import FitH5
from kmap.io.FitH5 import FitH5, FitH5QAxis
from ...widgets.XsocsPlot2D import XsocsPlot2D
......@@ -81,7 +81,7 @@ class DropPlotWidget(XsocsPlot2D):
scan_y = h5f.scan_y(entry)
self.__legend = self.setPlotData(scan_x, scan_y, data)
self.setGraphTitle(result + '/' + FitH5.axis_names[q_axis])
self.setGraphTitle(result + '/' + FitH5QAxis.axis_names[q_axis])
if __name__ == '__main__':
......
......@@ -29,7 +29,6 @@ __authors__ = ["D. Naudet"]
__license__ = "MIT"
__date__ = "01/01/2017"
import numpy as np
from silx.gui import qt as Qt
......@@ -38,7 +37,7 @@ from kmap.gui.model.Model import Model, RootNode
from kmap.gui.project.Hdf5Nodes import H5File
from kmap.gui.model.ModelDef import ModelRoles
from kmap.io.FitH5 import FitH5
from kmap.io.FitH5 import FitH5, FitH5QAxis
from ...widgets.XsocsPlot2D import XsocsPlot2D
from ...project.Hdf5Nodes import H5Base, H5NodeClassDef
......@@ -111,16 +110,16 @@ class FitProcessNode(FitEntryNode):
"""
Node linked to a process group in a FitH5 file.
"""
process = property(lambda self: self.h5Path.split('/')[1])
process = property(lambda self: self.h5Path.lstrip('/').split('/')[1])
def _loadChildren(self):
base = self.h5Path.rstrip('/')
entry = self.entry
process = self.process
children = []
print 'PATH', self.h5Path, self.entry, self.process
with FitH5(self.h5File, mode='r') as h5f:
results = h5f.results(entry, process)
results = h5f.get_result_names(entry, process)
for result in results:
child = FitResultNode(self.h5File, base + '/' + result)
children.append(child)
......@@ -203,11 +202,11 @@ class FitModel(Model):
return super(Model, self).mimeData(indexes)
if index.column() == 1:
q_axis = FitH5.qx_axis
q_axis = FitH5QAxis.qx_axis
elif index.column() == 2:
q_axis = FitH5.qy_axis
q_axis = FitH5QAxis.qy_axis
elif index.column() == 3:
q_axis = FitH5.qz_axis
q_axis = FitH5QAxis.qz_axis
else:
raise ValueError('Unexpected column.')
......
......@@ -38,17 +38,37 @@ from .XsocsH5Base import XsocsH5Base
FitResult = namedtuple('FitResult', ['name', 'qx', 'qy', 'qz'])
class FitH5(XsocsH5Base):
_axis_values = range(3)
qx_axis, qy_axis, qz_axis = _axis_values
class FitH5QAxis(object):
axis_values = range(3)
qx_axis, qy_axis, qz_axis = axis_values
axis_names = ('qx', 'qy', 'qz')
@staticmethod
def axis_name(axis):
return FitH5QAxis.axis_names[axis]
class FitH5(XsocsH5Base):
"""
File containing fit results.
Requirements :
- the number of sample position is defined at entry level : all processes
within the same entry are applied to the same sample points.
- all results arrays within an entry (even if they don't belong to the same
process) have the same size (equal to the number of sample points defined
for that entry)
- all arrays are 1D.
"""
# _axis_values = range(3)
# qx_axis, qy_axis, qz_axis = _axis_values
# axis_names = ('qx', 'qy', 'qz')
title_path = '{entry}/title'
start_time_path = '{entry}/start_time'
end_time_path = '{entry}/end_time'
date_path = '{entry}/{process}/date'
qspace_axis_path = '{entry}/qspace_axis/{axis}'
status_path = '{entry}/{process}/status'
status_path = '{entry}/{process}/status/{axis}'
configuration_path = '{entry}/{process}/configuration'
result_grp_path = '{entry}/{process}/results'
result_path = '{entry}/{process}/results/{result}/{axis}'
......@@ -56,11 +76,20 @@ class FitH5(XsocsH5Base):
scan_y_path = '{entry}/sample/y_pos'
def title(self, entry):
"""
Returns the title for the given entry.
:param entry:
:return:
"""
with self._get_file() as h5_file:
path = entry + '/title'
return h5_file[path][()]
def entries(self):
"""
Return the entry names.
:return:
"""
with self._get_file() as h5_file:
# TODO : this isnt pretty but for some reason the attrs.get() fails
# when there is no attribute NX_class (should return the default
......@@ -71,6 +100,11 @@ class FitH5(XsocsH5Base):
'NX_class'] == 'NXentry')])
def processes(self, entry):
"""
Return the processes names for the given entry.
:param entry:
:return:
"""
with self._get_file() as h5_file:
entry_grp = h5_file[entry]
processes = sorted([key for key in entry_grp
......@@ -79,117 +113,219 @@ class FitH5(XsocsH5Base):
'NX_class'] == 'NXprocess')])
return processes
def results(self, entry, process):
def get_result_names(self, entry, process):
"""
Returns the result names for the given process. Names are ordered
alphabetically.
:param entry:
:param process:
:return:
"""
results_path = self.result_grp_path.format(entry=entry,
process=process)
with self._get_file() as h5_file:
result_grp = h5_file[FitH5.result_grp_path.format(entry=entry,
process=process)]
return sorted(result_grp.keys())
def get_status(self, entry, process):
status_path = FitH5.status_path.format(entry=entry, process=process)
return sorted(h5_file[results_path].keys())
def get_status(self, entry, process, axis):
"""
Returns the fit status for the given entry/process/axis
:param entry:
:param process:
:param axis: FitH5QAxis.qx_axis, FitH5QAxis.qy_axis
or FitH5QAxis.qz_axis
:return:
"""
axis_name = FitH5QAxis.axis_name(axis)
status_path = FitH5.status_path.format(entry=entry,
process=process,
axis=axis_name)
return self._get_array_data(status_path)
def scan_x(self, entry):
"""
Return the sample points coordinates along x for the given entry.
:param entry:
:return:
"""
dset_path = FitH5.scan_x_path.format(entry=entry)
return self._get_array_data(dset_path)
def scan_y(self, entry):
"""
Return the sample points coordinates along y for the given entry.
:param entry:
:return:
"""
dset_path = FitH5.scan_y_path.format(entry=entry)
return self._get_array_data(dset_path)
def get_qx(self, entry):
return self.__get_axis_values(entry, FitH5.qx_axis)
"""
Returns the axis values for qx for the given entry.
:param entry:
:return:
"""
return self.__get_axis_values(entry, FitH5QAxis.qx_axis)
def get_qy(self, entry):
return self.__get_axis_values(entry, FitH5.qy_axis)
"""
Returns the axis values for qy for the given entry.
:param entry:
:return:
"""
return self.__get_axis_values(entry, FitH5QAxis.qy_axis)
def get_qz(self, entry):
return self.__get_axis_values(entry, FitH5.qz_axis)
"""
Returns the axis values for qz for the given entry.
:param entry:
:return:
"""
return self.__get_axis_values(entry, FitH5QAxis.qz_axis)
def __get_axis_values(self, entry, axis):
axis_name = FitH5.axis_names[axis]
"""
Returns the axis values.
:param entry:
:param axis: FitH5QAxis.qx_axis, FitH5QAxis.qy_axis
or FitH5QAxis.qz_axis
:return:
"""
axis_name = FitH5QAxis.axis_name(axis)
return self._get_array_data(FitH5.qspace_axis_path.format(
entry=entry, axis=axis_name))
def result(self, entry, process, result):
def get_axis_result(self, entry, process, result, axis):
"""
Returns the results for the given entry/process/result name/axis.
:param entry:
:param process:
:param result:
:param axis: FitH5QAxis.qx_axis, FitH5QAxis.qy_axis or
FitH5QAxis.qz_axis
:return:
"""
assert axis in FitH5QAxis.axis_values
axis_name = FitH5QAxis.axis_name(axis)
result_path = FitH5.result_path.format(entry=entry,
process=process,
result=result)
return self._get_array_data(result_path)
def get_axis_result(self, entry, process, name, q_axis):
assert q_axis in FitH5._axis_values
axis_name = self.axis_names[q_axis]
result_path = FitH5.result_path.format(entry=entry,
process=process,
result=name,
result=result,
axis=axis_name)
return self._get_array_data(result_path)
def get_qx_result(self, entry, process, result):
return self.get_axis_result(entry, process, result, FitH5.qx_axis)
"""
Returns the results (qx) for the given entry/process/result name.
:param entry:
:param process:
:param result:
:return:
"""
return self.get_axis_result(entry, process, result, FitH5QAxis.qx_axis)
def get_qy_result(self, entry, process, result):
return self.get_axis_result(entry, process, result, FitH5.qy_axis)
"""
Returns the results (qy) for the given entry/process/result name.
:param entry:
:param process:
:param result:
:return:
"""
return self.get_axis_result(entry, process, result, FitH5QAxis.qy_axis)
def get_qz_result(self, entry, process, result):
return self.get_axis_result(entry, process, result, FitH5.qz_axis)
"""
Returns the results (qz) for the given entry/process/result name.
:param entry:
:param process:
:param result:
:return:
"""
return self.get_axis_result(entry, process, result, FitH5QAxis.qz_axis)
def get_result(self, entry, process, result):
"""
Returns the results values (qx, qy, qz) for
the given entry/process/result name.
:param entry:
:param process:
:param result:
:return: a FitResult instance.
"""
with self:
results = {}
for axis in FitH5._axis_values:
results[FitH5.axis_names[axis]] = \
for axis in FitH5QAxis.axis_values:
results[FitH5QAxis.axis_name(axis)] = \
self.get_axis_result(entry, process, result, axis)
return FitResult(name=result, **results)
def export_txt(self, filename):
# TODO : change this when multiple entries/processes are supported
entry = self.entries()[0]
if not entry:
raise ValueError('No entries.')
process = self.processes(entry)[0]
if not process:
raise ValueError('No processed for entry {0}.'.format(entry))
def get_n_points(self, entry):
"""
Returns the number of sample positions for this entry.
:param entry:
:return:
"""
dset_path = FitH5.scan_x_path.format(entry=entry)
shape = self._get_array_data(dset_path, shape=True)
return shape[0]
def export_csv(self, entry, filename):
"""
Exports an entry results as csv.
:param entry:
:param filename:
:return:
"""
x, y = self.scan_x(entry), self.scan_y(entry)
processes = self.processes(entry)
if len(processes) == 0:
raise ValueError('No process found for entry {0}.'.format(entry))
# with open(filename, 'w+') as res_f:
with self:
with open(filename, 'w+') as res_f:
res_f.write('X Y '
'height_x center_x width_x '
'height_y center_y width_y '
'height_z center_z width_z '
'|q| status\n')
heights = self.get_result(entry, process, 'height')
positions = self.get_result(entry, process, 'position')
widths = self.get_result(entry, process, 'width')
x_height = heights.qx
x_center = positions.qx
x_width = widths.qx
y_height = heights.qy
y_center = positions.qy
y_width = widths.qy
z_height = heights.qz
z_center = positions.qz
z_width = widths.qz
q = np.sqrt(x_center ** 2 +
y_center ** 2 +
z_center ** 2)
status = self.get_status(entry, process)
x, y = self.scan_x(entry), self.scan_y(entry)
for i, s in enumerate(status):
r = [x[i], y[i],
x_height[i], x_center[i], x_width[i],
y_height[i], y_center[i], y_width[i],
z_height[i], z_center[i], z_width[i],
q[i], s]
res_str = '{0}\n'.format(' '.join(str(e) for e in r))
res_f.write(res_str)
header_process = ['_', 'process:']
header_list = ['X', 'Y']
for process in processes:
result_names = self.get_result_names(entry, process)
for axis in FitH5QAxis.axis_names:
for result_name in result_names:
header_process.append(process)
header_list.append(result_name + '_' + axis)
header_process.append(process)
header_list.append('status_' + axis)
header = ' '.join(header_process) + '\n' + ' '.join(header_list)
results = np.zeros((len(x), len(header_list)))
results[:, 0] = x
results[:, 1] = y
col_idx = 2
for process in processes:
result_names = self.get_result_names(entry, process)
for axis in FitH5QAxis.axis_values:
for result_name in result_names:
result = self.get_axis_result(entry,
process,
result_name,
axis)
results[:, col_idx] = result
col_idx += 1
results[:, col_idx] = self.get_status(entry,
process,
axis)
col_idx += 1
np.savetxt(filename,
results,
fmt='%.10g',
header=header,
comments='')
class FitH5Writer(FitH5):
......@@ -234,13 +370,16 @@ class FitH5Writer(FitH5):
def set_title(self, entry, title):
self._set_scalar_data(FitH5.title_path.format(entry), title)
def set_status(self, entry, process, data):
status_path = FitH5.status_path.format(entry=entry, process=process)
def set_status(self, entry, process, axis, data):
axis_name = FitH5QAxis.axis_name(axis)
status_path = FitH5.status_path.format(entry=entry,
process=process,
axis=axis_name)