grainPlotWidget.py 13.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# coding: utf-8
# /*##########################################################################
#
# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# ###########################################################################*/


__authors__ = ["J. Garriga"]
__license__ = "MIT"
__date__ = "05/08/2021"

from matplotlib.colors import hsv_to_rgb
import numpy

from silx.gui import qt
from silx.gui.colors import Colormap
from silx.gui.plot import Plot2D
from silx.image.marchingsquares import find_contours
38
from silx.math.medianfilter import medfilt2d
39
from silx.utils.enum import Enum as _Enum
40
from silx.io.dictdump import dicttonx
41

42
import darfix
43
44
from .operationThread import OperationThread

45
46
47

class Method(_Enum):
    """
    Different maps to show.

    The definition order matters: the members are added to the method combo
    box in this order and the widget addresses the combo-box rows by position.
    """
    COM = "Center of mass"
    FWHM = "FWHM"
    SKEWNESS = "Skewness"
    KURTOSIS = "Kurtosis"
    ORI_DIST = "Orientation distribution"
    MOSAICITY = "Mosaicity"


class GrainPlotWidget(qt.QMainWindow):
    """
    Widget to show a series of maps for the analysis of the data.

    The moment maps come from ``dataset.apply_moments``, which runs in an
    :class:`OperationThread`. For datasets with more than one dimension, the
    orientation distribution and its colour key come from
    ``dataset.compute_mosaicity_colorkey``.
    """

    # NOTE(review): this signal is declared but never emitted in this class;
    # confirm whether _updateData should emit it once the moments are ready.
    sigComputed = qt.Signal()

    # Position of each moment-based map inside ``self._moments[axis]``.
    _MOMENT_INDEX = {
        Method.COM: 0,
        Method.FWHM: 1,
        Method.SKEWNESS: 2,
        Method.KURTOSIS: 3,
    }
    # Maps available for every dataset once the moments are computed.
    _MOMENT_METHODS = tuple(_MOMENT_INDEX)
    # The colour-key image is drawn with a pixel size equal to the dimension
    # step divided by this factor.
    # NOTE(review): the meaning of the factor is not visible here; confirm
    # against Dataset.compute_mosaicity_colorkey.
    _COLOR_KEY_SCALE = 100

    def __init__(self, parent=None):
        super().__init__(parent)

        # State filled in by setDataset and by the computation thread. It is
        # initialised here so the slots below are safe before any dataset is set.
        self.dataset = None
        self.ori_dist = None
        self.hsv_key = None
        self._moments = None
        self._plots = []
        self._curves = []
        self._thread = None

        self._methodCB = qt.QComboBox()
        self._methodCB.addItems(Method.values())
        # Every map stays disabled until the data it needs has been computed.
        for method in Method:
            self._setMethodEnabled(method, False)
        self._methodCB.currentTextChanged.connect(self._updatePlot)

        # Container for one Plot2D per dataset dimension (filled in setDataset).
        self._plotWidget = qt.QWidget()
        self._plotWidget.setLayout(qt.QHBoxLayout())

        self._contoursPlot = Plot2D(parent=self)
        widget = qt.QWidget(parent=self)
        layout = qt.QVBoxLayout()

        # Orientation-distribution panel: number of levels, centring option,
        # compute button and the colour-key plot the contours are drawn on.
        self._levelsWidget = qt.QWidget()
        levelsLayout = qt.QGridLayout()
        levelsLabel = qt.QLabel("Number of levels:")
        self._levelsLE = qt.QLineEdit("20")
        self._levelsLE.setToolTip("Number of levels to use when finding the contours")
        self._levelsLE.setValidator(qt.QIntValidator())
        self._computeContoursB = qt.QPushButton("Compute")
        # Connected once here. Connecting in setDataset would add one more
        # connection, and one more contour computation, for every dataset.
        self._computeContoursB.clicked.connect(self._computeContours)
        self._centerDataCheckbox = qt.QCheckBox("Center angle values")
        self._centerDataCheckbox.stateChanged.connect(self._checkboxStateChanged)
        levelsLayout.addWidget(levelsLabel, 0, 0, 1, 1)
        levelsLayout.addWidget(self._levelsLE, 0, 1, 1, 1)
        levelsLayout.addWidget(self._centerDataCheckbox, 0, 2, 1, 1)
        levelsLayout.addWidget(self._computeContoursB, 1, 2, 1, 1)
        levelsLayout.addWidget(self._contoursPlot, 2, 0, 1, 3)
        self._levelsWidget.setLayout(levelsLayout)

        self._mosaicityPlot = Plot2D(parent=self)

        self._exportButton = qt.QPushButton("Export")
        self._exportButton.clicked.connect(self._saveMaps)
        # Nothing to export until the moments have been computed.
        self._exportButton.setEnabled(False)

        layout.addWidget(self._methodCB)
        layout.addWidget(self._levelsWidget)
        layout.addWidget(self._plotWidget)
        layout.addWidget(self._mosaicityPlot)
        layout.addWidget(self._exportButton)
        self._plotWidget.hide()
        self._mosaicityPlot.hide()
        self._mosaicityPlot.getColorBarWidget().hide()

        widget.setLayout(layout)
        widget.setSizePolicy(qt.QSizePolicy.Minimum, qt.QSizePolicy.Minimum)
        self.setCentralWidget(widget)

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _methodIndex(method):
        """Return the combo-box row of ``method`` (rows follow Method's order)."""
        return list(Method).index(method)

    def _setMethodEnabled(self, method, enabled):
        """Enable or disable the combo-box entry of ``method``."""
        self._methodCB.model().item(self._methodIndex(method)).setEnabled(enabled)

    def _availableMethods(self):
        """Return the maps that can be shown for the current dataset."""
        methods = list(self._MOMENT_METHODS)
        if self.dataset.dims.ndim > 1:
            methods += [Method.ORI_DIST, Method.MOSAICITY]
        return methods

    def _hasTransformation(self):
        """Return True when the dataset carries a pixel-to-µm transformation."""
        return self.dataset.transformation is not None

    def _stopListeningToThread(self):
        """Disconnect the previous computation so that its late result is ignored."""
        if self._thread is None:
            return
        try:
            self._thread.finished.disconnect(self._updateData)
        except (TypeError, RuntimeError):
            # Not connected any more: _updateData already ran for that thread
            # (it disconnects itself). PyQt raises TypeError, PySide RuntimeError.
            pass

    def _colorKeyAxes(self):
        """
        Return the axes used by the orientation-distribution view.

        :returns: ``(xdim, ydim, xstep, ystep)``. x is dimension 1 and y is
            dimension 0. A step is (last - first unique value) / (size - 1).
        """
        xdim = self.dataset.dims.get(1)
        ydim = self.dataset.dims.get(0)
        xstep = (xdim.unique_values[-1] - xdim.unique_values[0]) / (xdim.size - 1)
        ystep = (ydim.unique_values[-1] - ydim.unique_values[0]) / (ydim.size - 1)
        return xdim, ydim, xstep, ystep

    def _addColorKeyImage(self):
        """
        Draw ``self.hsv_key`` on the contours plot.

        The image is centred on 0 when "Center angle values" is checked, and
        starts at 0 otherwise. This matches the offsets applied to the contours
        in _computeContours.
        """
        xdim, ydim, xstep, ystep = self._colorKeyAxes()
        if self._centerDataCheckbox.isChecked():
            origin = (-xstep * (xdim.size - 1) / 2, -ystep * (ydim.size - 1) / 2)
        else:
            origin = (0., 0.)
        self._contoursPlot.addImage(hsv_to_rgb(self.hsv_key), xlabel=xdim.name,
                                    ylabel=ydim.name, origin=origin,
                                    scale=(xstep / self._COLOR_KEY_SCALE,
                                           ystep / self._COLOR_KEY_SCALE))

    def _momentMap(self, axis, method):
        """
        Return the map of ``method`` for dimension ``axis``.

        The FWHM map is the stored moment scaled by ``darfix.config.FWHM_VAL``.
        Display and export both use this scaled map.
        """
        data = self._moments[axis][self._MOMENT_INDEX[method]]
        if method == Method.FWHM:
            data = darfix.config.FWHM_VAL * data
        return data

    def _addSpatialImage(self, plot, image):
        """
        Show ``image`` in ``plot`` with the dataset's spatial calibration.

        Uses µm when the dataset has a transformation, and pixels otherwise.
        """
        if self._hasTransformation():
            plot.addImage(image, origin=self.origin, scale=self.scale,
                          xlabel='µm', ylabel='µm')
        else:
            plot.addImage(image, xlabel='pixels', ylabel='pixels')

    @staticmethod
    def _normalize(image):
        """
        Rescale ``image`` linearly to [0, 1].

        A constant image maps to zeros instead of NaN.
        """
        valueRange = numpy.ptp(image)
        if valueRange == 0:
            return numpy.zeros(numpy.shape(image))
        return (image - numpy.min(image)) / valueRange

    # ------------------------------------------------------------- public API

    def setDataset(self, dataset, indices=None, bg_indices=None, bg_dataset=None):
        """
        Dataset setter.

        Resets the widget for ``dataset``. When the dataset has more than one
        dimension, it also prepares the orientation-distribution view. It then
        starts computing the moments in an :class:`OperationThread`; the maps
        are enabled by _updateData once the thread finishes.

        :param Dataset dataset: dataset
        :param indices: passed to ``dataset.apply_moments``
        :param bg_indices: stored as ``self.bg_indices``
        :param bg_dataset: stored as ``self.bg_dataset``
        """
        self.dataset = dataset
        self.indices = indices
        self.bg_indices = bg_indices
        self.bg_dataset = bg_dataset
        # The previous moments belong to the previous dataset.
        self._moments = None
        self._exportButton.setEnabled(False)

        for method in Method:
            self._setMethodEnabled(method, False)

        if self._hasTransformation():
            transformation = self.dataset.transformation
            px = transformation[0][0][0]
            py = transformation[1][0][0]
            # NOTE(review): both steps divide by len(transformation[i][0]) (the
            # length of the first row); confirm this is intended for the y axis
            # of non-square images.
            xscale = (transformation[0][-1][-1] - px) / len(transformation[0][0])
            yscale = (transformation[1][-1][-1] - py) / len(transformation[1][0])
            self.origin = (px, py)
            self.scale = (xscale, yscale)

        # Contours drawn for a previous dataset no longer apply.
        self._contoursPlot.remove(kind='curve')
        if self.dataset.dims.ndim > 1:
            self.ori_dist, self.hsv_key = self.dataset.compute_mosaicity_colorkey()
            self._addColorKeyImage()
            self._contoursPlot.getColorBarWidget().hide()
            self._curvesColormap = Colormap(name='temperature',
                                            vmin=numpy.min(self.ori_dist),
                                            vmax=numpy.max(self.ori_dist))
            self._setMethodEnabled(Method.ORI_DIST, True)
            self._methodCB.setCurrentIndex(self._methodIndex(Method.ORI_DIST))
        else:
            self.ori_dist = None
            self.hsv_key = None

        # One plot per dimension, replacing the plots of the previous dataset.
        plotsLayout = self._plotWidget.layout()
        for i in reversed(range(plotsLayout.count())):
            plotsLayout.itemAt(i).widget().setParent(None)
        self._plots = []
        for axis, dim in self.dataset.dims:
            plot = Plot2D(parent=self)
            plot.setGraphTitle(dim.name)
            plot.setDefaultColormap(Colormap(name='viridis'))
            plotsLayout.addWidget(plot)
            self._plots.append(plot)

        # Compute the moments; a still-running previous computation is ignored.
        self._stopListeningToThread()
        self._thread = OperationThread(self, self.dataset.apply_moments)
        self._thread.setArgs(self.indices)
        self._thread.finished.connect(self._updateData)
        self._thread.start()

    # ------------------------------------------------------------------ slots

    def _updateData(self):
        """
        Update the plots with the data computed in the thread.

        Stores the moments, enables the maps that can now be shown and the
        export button, then displays the centre-of-mass map.
        """
        self._thread.finished.disconnect(self._updateData)
        if self._thread.data is None:
            print("\nComputation aborted")
            return

        self._moments = self._thread.data
        for method in self._availableMethods():
            self._setMethodEnabled(method, True)
        self._exportButton.setEnabled(True)

        comIndex = self._methodIndex(Method.COM)
        if self._methodCB.currentIndex() == comIndex:
            # currentTextChanged will not be emitted: refresh explicitly.
            self._updatePlot(self._methodCB.currentText())
        else:
            # Triggers _updatePlot through currentTextChanged. Switching first
            # avoids drawing a stale selection (e.g. MOSAICITY left over from a
            # multi-dimensional dataset) with moments that may not support it.
            self._methodCB.setCurrentIndex(comIndex)

    def _checkboxStateChanged(self, state):
        """
        Update widgets linked to the checkbox state.

        Redraws the colour key, centred or not, and removes the contours, which
        were placed for the previous state.

        :param int state: check state emitted by ``stateChanged``. The placement
            is read from the checkbox itself.
        """
        if self.hsv_key is None:
            # No orientation distribution for the current dataset (or none set).
            return
        self._addColorKeyImage()
        self._contoursPlot.remove(kind='curve')

    def _computeContours(self):
        """
        Draw iso-contours of ``self.ori_dist`` on the colour-key plot.

        The levels are evenly spaced between the minimum and maximum of the
        distribution, with the count taken from the "Number of levels" field.
        Invalid counts (empty, zero or negative) draw nothing.
        """
        self._contoursPlot.remove(kind='curve')
        if self.ori_dist is None:
            return
        try:
            nLevels = int(self._levelsLE.text())
        except ValueError:
            # The validator still allows intermediate input such as "".
            return
        if nLevels < 1:
            return

        levels = numpy.linspace(numpy.min(self.ori_dist), numpy.max(self.ori_dist), nLevels)
        colors = self._curvesColormap.applyToData(levels)

        xdim, ydim, xstep, ystep = self._colorKeyAxes()
        xoffset = yoffset = 0.
        if self._centerDataCheckbox.isChecked():
            xoffset = (xdim.unique_values[-1] - xdim.unique_values[0]) / 2
            yoffset = (ydim.unique_values[-1] - ydim.unique_values[0]) / 2

        self._curves = []
        for ilevel, level in enumerate(levels):
            # find_contours returns (row, col) coordinates: col -> x, row -> y.
            for icontour, contour in enumerate(find_contours(self.ori_dist, level)):
                if len(contour) == 0:
                    continue
                x = contour[:, 1] * xstep - xoffset
                y = contour[:, 0] * ystep - yoffset
                color = colors[ilevel]
                self._curves.append((x, y, color))
                self._contoursPlot.addCurve(x=x, y=y, linestyle="-", linewidth=2.0,
                                            legend="custom-polygon-%d-%d" % (ilevel, icontour),
                                            resetzoom=False, color=color)

    def _computeMosaicity(self):
        """
        Build the HSV mosaicity image from the centre-of-mass maps.

        Hue is the normalized COM map of dimension 0, saturation is the
        normalized COM map of dimension 1, and value is 1 everywhere. Convert
        the result with ``hsv_to_rgb`` to display it.

        :returns: the three channels stacked along axis 2
        """
        hue = self._normalize(self._moments[0][0])
        saturation = self._normalize(self._moments[1][0])
        value = numpy.ones(self._moments[0].shape[1:])
        return numpy.stack((hue, saturation, value), axis=2)

    def _updatePlot(self, method):
        """
        Show the map(s) of ``method``.

        :param str method: value of a :class:`Method` member (the combo-box text)
        """
        method = Method(method)
        self._levelsWidget.hide()
        self._mosaicityPlot.hide()

        if method == Method.ORI_DIST:
            self._plotWidget.hide()
            self._levelsWidget.show()
            return

        if self._moments is None:
            # Moments not computed (yet) for the current dataset.
            self._plotWidget.hide()
            return

        if method == Method.MOSAICITY:
            self._plotWidget.hide()
            self._addSpatialImage(self._mosaicityPlot, hsv_to_rgb(self._computeMosaicity()))
            self._mosaicityPlot.show()
            return

        self._plotWidget.show()
        for axis, plot in enumerate(self._plots):
            self._addSpatialImage(plot, self._momentMap(axis, method))

    def _opticolor(self, img, minc, maxc):
        """
        Clip ``img`` between two quantiles of its non-NaN values, then median-filter it.

        :param numpy.ndarray img: 2D image (not modified)
        :param float minc: lower quantile, as a fraction of the sorted values
        :param float maxc: upper quantile, as a fraction of the sorted values
        :returns: the clipped image filtered with ``medfilt2d``
        """
        img = img.copy()
        values = numpy.sort(img[~numpy.isnan(img)])
        if values.size:
            last = values.size - 1
            # Clamp so that a fraction of 1.0 selects the largest value instead
            # of indexing one past the end.
            imin = values[min(int(numpy.floor(values.size * minc)), last)]
            imax = values[min(int(numpy.floor(values.size * maxc)), last)]
            img[img > imax] = imax
            img[img < imin] = imin
        return medfilt2d(img)

    def _saveMaps(self):
        """
        Ask for an output file in a FileDialog and save the computed maps to it
        with :func:`silx.io.dictdump.dicttonx`.

        The moment maps of each dimension are written to
        ``entry/maps/<dimension name>/``. For datasets with more than one
        dimension, the orientation distribution and the mosaicity map are also
        written to ``entry/maps``, and the mosaicity map is linked from the
        default ``entry/data`` NXdata group. Does nothing if the moments have
        not been computed yet.
        """
        if self._moments is None:
            return

        fileDialog = qt.QFileDialog()
        fileDialog.setFileMode(fileDialog.AnyFile)
        fileDialog.setAcceptMode(fileDialog.AcceptSave)
        fileDialog.setOption(fileDialog.DontUseNativeDialog)
        fileDialog.setDefaultSuffix(".h5")
        if not fileDialog.exec_():
            return

        maps = {"@NX_class": "NXcollection"}
        # One group per dimension: writing every dimension to the same keys
        # would keep only the last dimension's moments.
        for axis, dim in self.dataset.dims:
            dimMaps = {method.name: self._momentMap(axis, method)
                       for method in self._MOMENT_METHODS}
            dimMaps["@NX_class"] = "NXcollection"
            maps[dim.name] = dimMaps

        entry = {"maps": maps, "@NX_class": "NXentry"}
        if self.dataset.dims.ndim > 1:
            maps[Method.ORI_DIST.name] = self.ori_dist
            maps[Method.MOSAICITY.name] = hsv_to_rgb(self._computeMosaicity())
            maps[Method.MOSAICITY.name + "@interpretation"] = "rgba-image"
            entry["data"] = {
                ">" + Method.MOSAICITY.name: "../maps/" + Method.MOSAICITY.name,
                "@signal": Method.MOSAICITY.name,
                "@NX_class": "NXdata",
            }
            entry["@default"] = "data"

        nx = {
            "entry": entry,
            "@NX_class": "NXroot",
            "@default": "entry",
        }
        dicttonx(nx, fileDialog.selectedFiles()[0])