Commit 66516f75 authored by Julia Garriga Ferrer's avatar Julia Garriga Ferrer
Browse files

Merge branch 'max_comp_bss' into 'master'

BSS and PCA Max Components

See merge request julia.garriga/darfix!40
parents acba390c e189ac9b
Pipeline #20470 passed with stage
in 3 minutes
......@@ -26,7 +26,7 @@
__authors__ = ["J. Garriga"]
__license__ = "MIT"
__date__ = "15/11/2019"
__date__ = "28/01/2020"
import cv2
......@@ -65,8 +65,6 @@ class BSS():
self._data_frames = data_frames
self.PCA_data = None
self.X = self._X_from_images(self.data_frames)
"""input frames (flatten)"""
......@@ -79,7 +77,7 @@ class BSS():
shape = data_frames.shape
return data_frames.reshape((shape[0], shape[1] * shape[2]))
def PCA(self, num_components=None):
def PCA(self, num_components=None, max_components=None):
"""
Applies Principal component analysis.
......@@ -93,9 +91,16 @@ class BSS():
log = "Computing PCA with {} components".format(num_components)
_logger.info(log)
if not num_components:
if not self.PCA_data:
self.PCA_data = cv2.PCACompute2(self.X, numpy.mean(self.X, axis=0).reshape(1, -1))
return self.PCA_data
if not max_components:
PCA_data = cv2.PCACompute2(self.X, numpy.mean(self.X,
axis=0).reshape(1, -1))
else:
if max_components < 1:
max_components *= self.X.shape[0]
PCA_data = cv2.PCACompute2(self.X, numpy.mean(self.X,
axis=0).reshape(1, -1),
maxComponents=int(max_components))
return PCA_data
model = PCA(n_components=num_components)
if self.X is not None:
W = model.fit_transform(self.X)
......@@ -381,8 +386,7 @@ class BSS():
W0 = G0.transpose().copy() # Make array C-contiguous
H0 = F0.transpose()
nmf = NMF(n_components=num_components, init='custom')
nmf = NMF(n_components=min(num_components, S.shape[0]), init='custom')
W = nmf.fit_transform(X.transpose(), W=W0, H=H0)
H = nmf.components_
......
......@@ -26,7 +26,7 @@
__authors__ = ["J. Garriga"]
__license__ = "MIT"
__date__ = "04/12/2019"
__date__ = "30/01/2020"
import numpy
......@@ -38,7 +38,7 @@ from darfix.core.blindSourceSeparation import BSS
from .operationThread import OperationThread
class PCAWidget(qt.QWidget):
class PCAWidget(qt.QMainWindow):
"""
Widget to apply PCA to a set of images and plot the eigenvalues found.
"""
......@@ -47,47 +47,59 @@ class PCAWidget(qt.QWidget):
def __init__(self, parent=None):
qt.QWidget.__init__(self, parent)
self.setLayout(qt.QVBoxLayout())
self._plot = Plot1D()
self.layout().addWidget(self._plot)
self._plot.setDataMargins(0.05, 0.05, 0.05, 0.05)
maxNComponentsLabel = qt.QLabel("Max number of components:")
self.maxNumComp = qt.QLineEdit("")
self.maxNumComp.setToolTip("Maximum number of components to compute")
self.maxNumComp.setValidator(qt.QDoubleValidator())
self.computeB = qt.QPushButton("Compute")
widget = qt.QWidget(parent=self)
layout = qt.QGridLayout()
layout.addWidget(maxNComponentsLabel, 0, 0, 1, 1)
layout.addWidget(self.maxNumComp, 0, 1, 1, 1)
layout.addWidget(self.computeB, 0, 2, 1, 1)
layout.addWidget(self._plot, 1, 0, 1, 3)
widget.setLayout(layout)
widget.setSizePolicy(qt.QSizePolicy.Minimum, qt.QSizePolicy.Minimum)
self.setCentralWidget(widget)
self._plot.hide()
def _computePCA(self):
self.computeB.setEnabled(False)
try:
txt = self.maxNumComp.text()
if txt != "":
maxNumComp = float(self.maxNumComp.text())
else:
maxNumComp = None
self._thread = OperationThread(self.BSS.PCA)
self._thread.setArgs(None, maxNumComp)
self._thread.finished.connect(self._updateData)
self._thread.start()
except Exception as e:
self.computeB.setEnabled(True)
raise e
def setDataset(self, dataset):
"""
Dataset setter. Starts BSS class and initalizes thread.
Dataset setter. Starts BSS class and initializes thread.
:param Dataset dataset: dataset
"""
self.dataset = dataset
self.BSS = BSS(self.dataset.hi_data)
self._thread = OperationThread(self.BSS.PCA)
self._computePCA()
def _computePCA(self):
"""
Slot that starts the thread to compute PCA.
"""
self._thread.finished.connect(self._updateData)
self._thread.start()
self.computeB.pressed.connect(self._computePCA)
def _updateData(self):
"""
Plots the eigenvalues.
"""
self._thread.finished.disconnect(self._updateData)
self.computeB.setEnabled(True)
mean, vecs, vals = self._thread.data
vals = [item for sublist in vals for item in sublist]
self._plot.show()
self._plot.addCurve(numpy.arange(len(vals)), vals, symbol='.', linestyle=' ')
self.signalComputed.emit()
def clearStack(self):
self._sv.setStack(None)
def setStack(self, *arg, **kwargs):
"""
Sets the data passed as aguments in the stack.
Mantains the current frame showed in the view.
"""
nframe = self._sv.getFrameNumber()
self._sv.setStack(*arg, **kwargs)
self._sv.setFrameNumber(nframe)
......@@ -26,7 +26,7 @@
__authors__ = ["J. Garriga"]
__license__ = "MIT"
__date__ = "16/12/2019"
__date__ = "28/01/2020"
import numpy
......@@ -38,6 +38,7 @@ from silx.gui.widgets.FrameBrowser import HorizontalSliderWithBrowser
import darfix
from darfix.core.blindSourceSeparation import Method, BSS
from .operationThread import OperationThread
class BSSWidget(qt.QMainWindow):
......@@ -60,6 +61,14 @@ class BSSWidget(qt.QMainWindow):
self.nComponentsLE = qt.QLineEdit("1")
self.nComponentsLE.setValidator(qt.QIntValidator())
self.computeButton = qt.QPushButton("Compute")
maxNComponentsLabel = qt.QLabel("Max number of components:")
self.maxNumComp = qt.QLineEdit("")
self.maxNumComp.setToolTip("For a specific number of components enter an "
"integer, for a\npercentage enter a float between "
"0 (included) and 1 (not included).\n"
"Float 0.5 will take as max number the 50% of "
"the images.\nEmpty text computes all components.")
self.maxNumComp.setValidator(qt.QDoubleValidator())
self.detectButton = qt.QPushButton("Detect number of components")
self.computeButton.setEnabled(False)
self.detectButton.setEnabled(False)
......@@ -70,6 +79,8 @@ class BSSWidget(qt.QMainWindow):
layout.addWidget(nComponentsLabel, 0, 2, 1, 1)
layout.addWidget(self.nComponentsLE, 0, 3, 1, 1)
layout.addWidget(self.computeButton, 0, 4, 1, 1)
layout.addWidget(maxNComponentsLabel, 1, 2, 1, 1)
layout.addWidget(self.maxNumComp, 1, 3, 1, 1)
layout.addWidget(self.detectButton, 1, 4, 1, 1)
top_widget.setLayout(layout)
......@@ -187,49 +198,75 @@ class BSSWidget(qt.QMainWindow):
Computes blind source separation with the chosen method.
"""
self.computeButton.setEnabled(False)
qt.QApplication.setOverrideCursor(qt.Qt.WaitCursor)
try:
method = Method(self.methodCB.currentText())
n_comp = int(self.nComponentsLE.text())
if method == Method.PCA:
comp, self.W = self.BSS.PCA(n_comp)
elif method == Method.NNICA:
comp, self.W = self.BSS.non_negative_ICA(n_comp)
elif method == Method.NMF:
comp, self.W = self.BSS.NMF(n_comp)
elif method == Method.NNICA_NMF:
comp, self.W = self.BSS.NNICA_NMF(n_comp)
else:
raise ValueError('BSS method not managed')
self.comp = comp.reshape(n_comp, self.dataset.data.shape[1], self.dataset.data.shape[2])
self._sv_components.setStack(self.comp)
self._plot_rocking_curves.getPlot().clear()
self._plot_rocking_curves.getPlot().setGraphYLabel("Values")
for i in range(len(self.W.T)):
self._plot_rocking_curves.getPlot().addCurve(numpy.arange(len(self.W.T[i])), self.W.T[i], legend=str(i))
self.nComponentsLE.setEnabled(False)
method = Method(self.methodCB.currentText())
n_comp = int(self.nComponentsLE.text())
if method == Method.PCA:
self._thread = OperationThread(self.BSS.PCA)
elif method == Method.NNICA:
self._thread = OperationThread(self.BSS.non_negative_ICA)
elif method == Method.NMF:
self._thread = OperationThread(self.BSS.NMF)
elif method == Method.NNICA_NMF:
self._thread = OperationThread(self.BSS.NNICA_NMF)
else:
raise ValueError('BSS method not managed')
self._thread.setArgs(n_comp)
self._thread.finished.connect(self._setPlot)
self._thread.start()
def _setPlot(self):
self._thread.finished.disconnect(self._setPlot)
comp, self.W = self._thread.data
n_comp = int(self.nComponentsLE.text())
if comp.shape[0] < n_comp:
n_comp = comp.shape[0]
msg = qt.QMessageBox()
msg.setIcon(qt.QMessageBox.Information)
msg.setText("Found only {0} components".format(n_comp))
msg.setStandardButtons(qt.QMessageBox.Ok)
msg.exec_()
self.comp = comp.reshape(n_comp, self.dataset.data.shape[1], self.dataset.data.shape[2])
self._sv_components.setStack(self.comp)
self._plot_rocking_curves.getPlot().clear()
self._plot_rocking_curves.getPlot().setGraphYLabel("Values")
for i in range(len(self.W.T)):
self._plot_rocking_curves.getPlot().addCurve(numpy.arange(len(self.W.T[i])), self.W.T[i], legend=str(i))
if self.dataset.reshaped_data is not None:
values = self.dataset.get_dimensions_values().astype(numpy.float)
colormap = Colormap(name='jet', normalization='linear')
self._scatter_rocking_curves.setData(values[0], values[1], self.W.T[0])
self._scatter_rocking_curves.setColormap(colormap)
self._scatter_rocking_curves.getPlotWidget().setGraphXLabel(self.dataset.dims.get(0).name)
self._scatter_rocking_curves.getPlotWidget().setGraphYLabel(self.dataset.dims.get(1).name)
self._scatter_rocking_curves.resetZoom()
self._plot_rocking_curves.getPlot().setActiveCurve("0")
self.bottom_widget.show()
finally:
self.computeButton.setEnabled(True)
qt.QApplication.restoreOverrideCursor()
if self.dataset.reshaped_data is not None:
values = self.dataset.get_dimensions_values().astype(numpy.float)
colormap = Colormap(name='jet', normalization='linear')
self._scatter_rocking_curves.setData(values[0], values[1], self.W.T[0])
self._scatter_rocking_curves.setColormap(colormap)
self._scatter_rocking_curves.getPlotWidget().setGraphXLabel(self.dataset.dims.get(0).name)
self._scatter_rocking_curves.getPlotWidget().setGraphYLabel(self.dataset.dims.get(1).name)
self._scatter_rocking_curves.resetZoom()
self._plot_rocking_curves.getPlot().setActiveCurve("0")
self.bottom_widget.show()
self.computeButton.setEnabled(True)
self.nComponentsLE.setEnabled(True)
def _detectComp(self):
qt.QApplication.setOverrideCursor(qt.Qt.WaitCursor)
mean, vecs, vals = self.BSS.PCA()
txt = self.maxNumComp.text()
if txt != "":
maxNumComp = float(self.maxNumComp.text())
else:
maxNumComp = None
self.detectButton.setEnabled(False)
self._thread = OperationThread(self.BSS.PCA)
self._thread.setArgs(None, maxNumComp)
self._thread.finished.connect(self._setNumComp)
self._thread.start()
def _setNumComp(self):
self._thread.finished.disconnect(self._setNumComp)
mean, vecs, vals = self._thread.data
vals /= numpy.sum(vals)
components = len(vals[vals > 0.01])
self.detectButton.setEnabled(True)
self.nComponentsLE.setText(str(components))
qt.QApplication.restoreOverrideCursor()
def _activeCurveChanged(self, prev_legend=None, legend=None):
if legend:
......
......@@ -32,6 +32,7 @@ __date__ = "03/12/2019"
import signal
import sys
from pathlib import Path
import glob
import numpy
......@@ -56,8 +57,7 @@ def exec_():
timer.timeout.connect(lambda: None)
w = PCAWidget()
images = glob.glob("figures/*")
images = glob.glob(str(Path(__file__).parent) + "/figures/*")
stack = []
for i, image in enumerate(images):
......
......@@ -55,7 +55,6 @@ class PCAWidgetOW(OWWidget):
super().__init__()
self._widget = PCAWidget(parent=self)
self._widget.signalComputed.connect(self._show)
self.controlArea.layout().addWidget(self._widget)
@Inputs.dataset
......@@ -63,6 +62,4 @@ class PCAWidgetOW(OWWidget):
if dataset:
self._widget.setDataset(dataset=dataset)
self.Outputs.dataset.send(dataset)
def _show(self):
self.open()
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment