# grainPlotWidget.py
# coding: utf-8
# /*##########################################################################
#
# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# ###########################################################################*/


__authors__ = ["J. Garriga"]
__license__ = "MIT"
__date__ = "05/08/2021"

from matplotlib.colors import hsv_to_rgb
import numpy

from silx.gui import qt
from silx.gui.colors import Colormap
from silx.gui.plot import Plot2D
from silx.image.marchingsquares import find_contours
38
from silx.math.medianfilter import medfilt2d
39
from silx.utils.enum import Enum as _Enum
40
from silx.io.dictdump import dicttonx
41

42
import darfix
43
44
from .operationThread import OperationThread

45
46
47

class Method(_Enum):
    """
    Different maps to show.

    Each member's value is the label listed in the GrainPlotWidget combo box
    (the box is filled from ``Method.values()``), and ``Method(label)`` maps a
    combo box text back to its member.
    """
    COM = "Center of mass"
    FWHM = "FWHM"
    SKEWNESS = "Skewness"
    KURTOSIS = "Kurtosis"
    ORI_DIST = "Orientation distribution"
    MOSAICITY = "Mosaicity"


class GrainPlotWidget(qt.QMainWindow):
    """
    Widget to show a series of maps for the analysis of the data.

    For every dimension of the dataset it shows the center of mass, FWHM,
    skewness and kurtosis maps computed by ``dataset.apply_moments`` (run in a
    background thread). When the dataset has more than one dimension it also
    shows the orientation distribution colour key (with optional iso-contours)
    and a mosaicity map. The maps can be exported to a NeXus/HDF5 file.
    """
    # NOTE(review): sigComputed is not emitted anywhere in this module --
    # confirm whether external code expects it after the moments are computed.
    sigComputed = qt.Signal()

    # The orientation colour key is drawn with pixels this many times smaller
    # than one step of the dimension grid; presumably
    # dataset.compute_mosaicity_colorkey() samples the key that much finer --
    # TODO confirm.
    _KEY_SAMPLING = 100

    # Combo box position of the first map that needs more than one dimension;
    # the maps before it (COM, FWHM, SKEWNESS, KURTOSIS) exist for any dataset.
    _ORI_DIST_INDEX = list(Method).index(Method.ORI_DIST)

    # Index of each moment map in the per-dimension result of apply_moments.
    _MOMENT_INDEX = {Method.COM: 0, Method.FWHM: 1, Method.SKEWNESS: 2, Method.KURTOSIS: 3}

    def __init__(self, parent=None):
        # super() initialises the QMainWindow part; QWidget.__init__ is the
        # wrong initializer for a QMainWindow subclass on some Qt bindings.
        super().__init__(parent)

        # State filled by setDataset() / _updateData(). It is initialised here
        # so that slots triggered early (Export, Compute, the checkbox) find
        # well-defined values instead of raising AttributeError.
        self.dataset = None
        self.origin = None
        self.scale = None
        self.ori_dist = None
        self.hsv_key = None
        self._moments = None
        self._thread = None
        self._plots = []
        self._curves = []

        self._methodCB = qt.QComboBox()
        self._methodCB.addItems(Method.values())
        # Every map stays disabled until its data is available.
        for i in range(len(Method)):
            self._methodCB.model().item(i).setEnabled(False)
        self._methodCB.currentTextChanged.connect(self._updatePlot)

        self._plotWidget = qt.QWidget()
        plotsLayout = qt.QHBoxLayout()
        self._plotWidget.setLayout(plotsLayout)

        self._contoursPlot = Plot2D(parent=self)
        widget = qt.QWidget(parent=self)
        layout = qt.QVBoxLayout()

        self._levelsWidget = qt.QWidget()
        levelsLayout = qt.QGridLayout()
        levelsLabel = qt.QLabel("Number of levels:")
        self._levelsLE = qt.QLineEdit("20")
        self._levelsLE.setToolTip("Number of levels to use when finding the contours")
        self._levelsLE.setValidator(qt.QIntValidator())
        self._computeContoursB = qt.QPushButton("Compute")
        # Connected exactly once: connecting in setDataset() stacked one more
        # connection (and one more contour computation per click) per dataset.
        self._computeContoursB.clicked.connect(self._computeContours)
        self._centerDataCheckbox = qt.QCheckBox("Center angle values")
        self._centerDataCheckbox.stateChanged.connect(self._checkboxStateChanged)

        levelsLayout.addWidget(levelsLabel, 0, 0, 1, 1)
        levelsLayout.addWidget(self._levelsLE, 0, 1, 1, 1)
        levelsLayout.addWidget(self._centerDataCheckbox, 0, 2, 1, 1)
        levelsLayout.addWidget(self._computeContoursB, 1, 2, 1, 1)
        levelsLayout.addWidget(self._contoursPlot, 2, 0, 1, 3)
        self._levelsWidget.setLayout(levelsLayout)

        self._mosaicityPlot = Plot2D(parent=self)
        self._exportButton = qt.QPushButton("Export")
        self._exportButton.clicked.connect(self.exportMaps)

        layout.addWidget(self._methodCB)
        layout.addWidget(self._levelsWidget)
        layout.addWidget(self._plotWidget)
        layout.addWidget(self._mosaicityPlot)
        layout.addWidget(self._exportButton)
        self._plotWidget.hide()
        self._mosaicityPlot.hide()
        self._mosaicityPlot.getColorBarWidget().hide()

        widget.setLayout(layout)
        widget.setSizePolicy(qt.QSizePolicy.Minimum, qt.QSizePolicy.Minimum)
        self.setCentralWidget(widget)

    def setDataset(self, dataset, indices=None, bg_indices=None, bg_dataset=None):
        """
        Dataset setter.

        Stores the dataset and draws its orientation colour key (when it has
        more than one dimension). It then starts computing the moments in a
        background thread and creates one empty plot per dimension. The maps
        are filled by :meth:`_updateData` once the thread finishes.

        :param Dataset dataset: dataset
        :param indices: passed to ``dataset.apply_moments``
        :param bg_indices: only stored, as ``self.bg_indices``
        :param bg_dataset: only stored, as ``self.bg_dataset``
        """
        self.dataset = dataset
        self.indices = indices
        self.bg_indices = bg_indices
        self.bg_dataset = bg_dataset

        # Drop the results of the previous dataset so they cannot be plotted
        # or exported together with the new one.
        self._moments = None
        self.ori_dist = None
        self.hsv_key = None
        self._contoursPlot.remove(kind='curve')
        self._curves = []

        model = self._methodCB.model()
        for i in range(len(Method)):
            model.item(i).setEnabled(False)

        self._updateSpatialCalibration()

        if self.dataset.dims.ndim > 1:
            self.ori_dist, self.hsv_key = self.dataset.compute_mosaicity_colorkey()
            # Honour the current "Center angle values" state, as the contours do.
            self._plotColorKey(self._centerDataCheckbox.isChecked())
            self._contoursPlot.getColorBarWidget().hide()
            self._curvesColormap = Colormap(name='temperature',
                                            vmin=numpy.min(self.ori_dist),
                                            vmax=numpy.max(self.ori_dist))
            model.item(self._ORI_DIST_INDEX).setEnabled(True)
            self._methodCB.setCurrentIndex(self._ORI_DIST_INDEX)

        if self._thread is not None and self._thread.isRunning():
            # A computation for a previous dataset is still running. Stop
            # listening to it; otherwise its _updateData call would disconnect
            # and read the new thread instead. (Assumes OperationThread is a
            # QThread, as its start()/finished usage suggests.)
            self._thread.finished.disconnect(self._updateData)
        self._thread = OperationThread(self, self.dataset.apply_moments)
        self._thread.setArgs(self.indices)
        self._thread.finished.connect(self._updateData)
        self._thread.start()

        plotsLayout = self._plotWidget.layout()
        for i in reversed(range(plotsLayout.count())):
            plotsLayout.itemAt(i).widget().setParent(None)

        self._plots = []
        for _axis, dim in self.dataset.dims:
            plot = Plot2D(parent=self)
            plot.setGraphTitle(dim.name)
            plot.setDefaultColormap(Colormap(name='viridis'))
            plotsLayout.addWidget(plot)
            self._plots.append(plot)

    def _updateSpatialCalibration(self):
        """
        Set ``self.origin`` and ``self.scale`` (used with 'µm' axis labels) from
        ``dataset.transformation``, or set both to ``None`` when there is none.
        """
        transformation = self.dataset.transformation
        if transformation is None:
            self.origin = None
            self.scale = None
            return
        # Assumes transformation[0] / transformation[1] are 2D (rows, columns)
        # grids of x / y positions -- TODO confirm against Dataset.
        xgrid, ygrid = transformation[0], transformation[1]
        px = xgrid[0][0]
        py = ygrid[0][0]
        self.origin = (px, py)
        # x varies along a row, so its span is divided by the row length.
        # y varies down a column, so its span is divided by the number of rows.
        self.scale = ((xgrid[-1][-1] - px) / len(xgrid[0]),
                      (ygrid[-1][-1] - py) / len(ygrid))

    def _updateData(self):
        """
        Updates the plots with the data computed in the thread.

        Stores the moments, refreshes the current map and enables the maps
        that can now be shown (all of them for multi-dimensional datasets, the
        four moment maps otherwise).
        """
        self._thread.finished.disconnect(self._updateData)
        if self._thread.data is None:
            print("\nComputation aborted")
            return

        self._moments = self._thread.data
        self._updatePlot(self._methodCB.currentText())
        available = len(Method) if self.dataset.dims.ndim > 1 else self._ORI_DIST_INDEX
        model = self._methodCB.model()
        for i in range(available):
            model.item(i).setEnabled(True)
        self._methodCB.setCurrentIndex(0)

    def _angleSteps(self):
        """
        Return ``(xdim, ydim, xstep, ystep)`` for the orientation plots.

        Dimension 1 is shown on the x axis and dimension 0 on the y axis. Each
        step is the mean spacing between the first and last unique values of
        that dimension.
        """
        xdim = self.dataset.dims.get(1)
        ydim = self.dataset.dims.get(0)
        xstep = (xdim.unique_values[-1] - xdim.unique_values[0]) / (xdim.size - 1)
        ystep = (ydim.unique_values[-1] - ydim.unique_values[0]) / (ydim.size - 1)
        return xdim, ydim, xstep, ystep

    @staticmethod
    def _halfRanges(xdim, ydim):
        """Return half the value range of ``xdim`` and ``ydim`` (the centering offsets)."""
        return ((xdim.unique_values[-1] - xdim.unique_values[0]) / 2,
                (ydim.unique_values[-1] - ydim.unique_values[0]) / 2)

    def _plotColorKey(self, centered):
        """
        Draw the orientation colour key ``self.hsv_key`` in the contours plot.

        :param bool centered: if True, shift the image so that the middle of
            both angle ranges is at 0
        """
        xdim, ydim, xstep, ystep = self._angleSteps()
        if centered:
            xoffset, yoffset = self._halfRanges(xdim, ydim)
            origin = (-xoffset, -yoffset)
        else:
            origin = (0., 0.)
        self._contoursPlot.addImage(hsv_to_rgb(self.hsv_key), xlabel=xdim.name, ylabel=ydim.name,
                                    origin=origin,
                                    scale=(xstep / self._KEY_SAMPLING, ystep / self._KEY_SAMPLING))

    def _checkboxStateChanged(self, state):
        """
        Update widgets linked to the checkbox state.

        Redraws the colour key centred (or not) and removes the contours,
        which were computed for the previous origin.

        :param int state: new Qt check state (non-zero means checked)
        """
        if self.hsv_key is None:
            # The current dataset has no orientation colour key.
            return
        self._plotColorKey(bool(state))
        self._contoursPlot.remove(kind='curve')
        self._curves = []

    def _computeContours(self):
        """
        Draw iso-contours of the orientation distribution on top of its colour key.

        The number of levels comes from the "Number of levels" field. The levels
        are evenly spaced between the minimum and maximum of ``self.ori_dist``.
        Besides clearing old contours, nothing is drawn when there is no
        orientation distribution or when the field does not hold a positive
        integer.
        """
        self._contoursPlot.remove(kind='curve')
        self._curves = []
        if self.ori_dist is None:
            return
        try:
            nlevels = int(self._levelsLE.text())
        except ValueError:
            # The QIntValidator still lets the field be empty or just "-".
            return
        if nlevels < 1:
            return

        xdim, ydim, xstep, ystep = self._angleSteps()
        if self._centerDataCheckbox.isChecked():
            xoffset, yoffset = self._halfRanges(xdim, ydim)
        else:
            xoffset, yoffset = 0., 0.

        levels = numpy.linspace(numpy.min(self.ori_dist), numpy.max(self.ori_dist), nlevels)
        colors = self._curvesColormap.applyToData(levels)
        for ilevel, (level, color) in enumerate(zip(levels, colors)):
            for icontour, contour in enumerate(find_contours(self.ori_dist, level)):
                if len(contour) == 0:
                    continue
                # Contour points are (row, column) indices of ori_dist. Convert
                # them to the units and origin used by the colour key image.
                x = contour[:, 1] * xstep - xoffset
                y = contour[:, 0] * ystep - yoffset
                self._curves.append((x, y, color))
                # One legend per (level, contour) pair so that curves do not
                # replace each other. The former "%d" % i * n expression repeated
                # the whole string n times because of operator precedence.
                legend = "custom-polygon-%d-%d" % (ilevel, icontour)
                self._contoursPlot.addCurve(x=x, y=y, linestyle="-", linewidth=2.0,
                                            legend=legend, resetzoom=False, color=color)

    @staticmethod
    def _normalize(image):
        """
        Linearly rescale ``image`` to [0, 1], ignoring NaN pixels (they stay NaN).

        A constant image maps to 0 instead of dividing by a zero range.
        """
        low = numpy.nanmin(image)
        span = numpy.nanmax(image) - low
        if span == 0:
            return numpy.zeros(numpy.shape(image))
        return (image - low) / span

    def _computeMosaicity(self):
        """
        Return the mosaicity map as an HSV image of shape (rows, columns, 3).

        Hue is the normalized center of mass of the first dimension and
        saturation the normalized center of mass of the second one. Value is
        constant 1. Convert the result with ``hsv_to_rgb`` for display.
        """
        hue = self._normalize(self._moments[0][0])
        saturation = self._normalize(self._moments[1][0])
        value = numpy.ones(self._moments[0].shape[1:])
        return numpy.stack((hue, saturation, value), axis=2)

    def _updatePlot(self, method):
        """
        Show the map selected in the method combo box.

        :param str method: value of a :class:`Method` member (the combo box text)
        """
        method = Method(method)
        self._levelsWidget.hide()
        self._mosaicityPlot.hide()

        if method == Method.ORI_DIST:
            self._levelsWidget.show()
            self._plotWidget.hide()
            return

        if method == Method.MOSAICITY:
            self._plotWidget.hide()
        else:
            self._plotWidget.show()

        if self._moments is None:
            # Moments not computed yet; the matching entries are disabled anyway.
            return

        if method == Method.MOSAICITY:
            image = hsv_to_rgb(self._computeMosaicity())
            if self.origin is not None:
                self._mosaicityPlot.addImage(image, origin=self.origin, scale=self.scale,
                                             xlabel='µm', ylabel='µm')
            else:
                self._mosaicityPlot.addImage(image)
            self._mosaicityPlot.show()
            return

        index = self._MOMENT_INDEX[method]
        for plot, moments in zip(self._plots, self._moments):
            image = moments[index]
            if method == Method.FWHM:
                plot.addImage(darfix.config.FWHM_VAL * image)
            elif method == Method.COM:
                if self.origin is not None:
                    plot.addImage(image, origin=self.origin, scale=self.scale, xlabel='µm', ylabel='µm')
                else:
                    plot.addImage(image, xlabel='pixels', ylabel='pixels')
            else:
                plot.addImage(image)

    def _opticolor(self, img, minc, maxc):
        """
        Clip ``img`` between two quantiles of its non-NaN values, then median filter it.

        :param numpy.ndarray img: image to process (not modified)
        :param float minc: lower quantile, as a fraction in [0, 1]
        :param float maxc: upper quantile, as a fraction in [0, 1]
        :return: the clipped, median-filtered copy of ``img``
        """
        img = img.copy()
        values = numpy.sort(img[~numpy.isnan(img)])
        if values.size:
            last = values.size - 1
            # Clamp so that e.g. maxc == 1 selects the largest value instead of
            # indexing one past the end.
            imin = values[min(max(int(numpy.floor(values.size * minc)), 0), last)]
            imax = values[min(max(int(numpy.floor(values.size * maxc)), 0), last)]
            img[img > imax] = imax
            img[img < imin] = imin
        return medfilt2d(img)

    def _momentMaps(self, axis):
        """Return the COM/FWHM/skewness/kurtosis maps of dimension ``axis`` as an NXcollection dict."""
        moments = self._moments[axis]
        maps = {method.name: moments[index] for method, index in self._MOMENT_INDEX.items()}
        maps["@NX_class"] = "NXcollection"
        return maps

    def exportMaps(self):
        """
        Ask the user for a file name and save the computed maps to it (NeXus/HDF5).

        For multi-dimensional datasets the file holds:

        - the orientation distribution;
        - the mosaicity RGB image, which is the default signal;
        - for each dimension, a ``maps/<dimension name>`` group with its COM,
          FWHM, skewness and kurtosis maps.

        For one-dimensional datasets the four moment maps are written directly
        under ``maps``, with COM as the default signal. Nothing happens until
        the moments have been computed.
        """
        if self.dataset is None or self._moments is None:
            return

        if self.dataset.dims.ndim > 1:
            maps = {
                Method.ORI_DIST.name: self.ori_dist,
                Method.MOSAICITY.name: hsv_to_rgb(self._computeMosaicity()),
                Method.MOSAICITY.name + "@interpretation": "rgba-image",
                "@NX_class": "NXcollection",
            }
            # One group per dimension: the maps share the same names for all
            # dimensions, so a flat dict kept only the last dimension.
            for axis, dim in self.dataset.dims:
                maps[dim.name] = self._momentMaps(axis)
            signal = Method.MOSAICITY.name
        else:
            maps = self._momentMaps(0)
            signal = Method.COM.name

        nx = {
            "entry": {
                "data": {
                    ">" + signal: "../maps/" + signal,
                    "@signal": signal,
                    "@NX_class": "NXdata",
                },
                "maps": maps,
                "@NX_class": "NXentry",
                "@default": "data",
            },
            "@NX_class": "NXroot",
            "@default": "entry",
        }

        fileDialog = qt.QFileDialog()
        fileDialog.setFileMode(fileDialog.AnyFile)
        fileDialog.setAcceptMode(fileDialog.AcceptSave)
        fileDialog.setOption(fileDialog.DontUseNativeDialog)
        fileDialog.setDefaultSuffix(".h5")
        if fileDialog.exec_():
            dicttonx(nx, fileDialog.selectedFiles()[0])