Commit 342c7ddf authored by Joao P C Bertoldo

correct a bug in zmoving medians

parent d2c2ec83
1 merge request: !6 (Draft: Resolve "py-bkg-rm")
@@ -6,6 +6,7 @@ Open 3D _data from h5 files and return numpy array.
 import logging
 from typing import Optional
 import numpy as np
+from numpy import ndarray
 from Orange.widgets import gui
 from Orange.widgets.settings import Setting
@@ -191,7 +192,8 @@ class LoadDataFromDataUrl(OWWidget):
         _logger.info("done loading data")
         if self._thread.data is not None:
-            self.Outputs.data.send(self._thread.data)
+            # todo create option for this
+            self.Outputs.data.send(self._thread.data.astype(np.float32))
         else:
             self.error("data not loaded, please verify the inputs (see logs for more detail)")
@@ -35,14 +35,17 @@ class RemoveZmovingMediansThread(QThread):
     def run(self):
-        data_processed, self.backgrounds = remove_moving_medians(
-            data=self.data,
-            multiprocessing_rawarray=self.multiprocessing_array,
-            median_window=self.median_window,
-            median_validity=self.median_validity,
-        )
+        try:
+            self.data_processed, self.backgrounds = remove_moving_medians(
+                data=self.data,
+                multiprocessing_rawarray=self.multiprocessing_array,
+                median_window=self.median_window,
+                median_validity=self.median_validity,
+            )
-        self.data_processed = data_processed.copy()
+        except Exception as ex:
+            _logger.exception(ex)
+            self.parent.error(f"exception in `{self.__class__.__name__}`")


 class RemoveZmovingMedian(OWWidget):
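
The try/except introduced here is a common guard for Qt worker threads: an exception raised inside run() would otherwise die with the thread and leave the widget waiting. A minimal sketch of the pattern, assuming a qtpy/PyQt environment (`expensive_computation` is a hypothetical stand-in, not this project's API):

import logging

from qtpy.QtCore import QThread

_logger = logging.getLogger(__name__)


def expensive_computation():
    # hypothetical stand-in for a long-running call like remove_moving_medians
    return 42


class WorkerThread(QThread):
    """Sketch: exceptions in run() are logged instead of silently killing the thread."""

    result = None

    def run(self):
        try:
            self.result = expensive_computation()
        except Exception as ex:
            _logger.exception(ex)
            self.result = None  # the owning widget checks for None before sending outputs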
@@ -60,7 +63,8 @@ class RemoveZmovingMedian(OWWidget):
     # Tuple[ndarray, RawArray]
     class Outputs:
-        shared_data = Output("shared_data", tuple)  # todo document this properly
+        data_shared = Output("data_shared", tuple)  # todo document this properly
+        data = Output("data", ndarray)
         backgrounds = Output("backgrounds", ndarray)

     median_window = Setting(500, schema_only=True)
@@ -104,7 +108,10 @@ class RemoveZmovingMedian(OWWidget):
     @Inputs.shared_data
     def set_shared_data(self, shared_data):
-        self.data, self.multiprocessing_array = shared_data
+        if shared_data is None:
+            self.data, self.multiprocessing_array = None, None
+        else:
+            self.data, self.multiprocessing_array = shared_data
         self._on_input_change()

     def _on_input_change(self):
@@ -114,8 +121,9 @@ class RemoveZmovingMedian(OWWidget):
             self._thread.quit()  # todo check if this works
             self._computing = False

+        self.Outputs.data_shared.send(None)
+        self.Outputs.data.send(None)
-        self.Outputs.multiprocessing_array.send(None)
         self.Outputs.backgrounds.send(None)

         if self.data is None or self.multiprocessing_array is None:
             self.run_button.setDisabled(True)
@@ -159,11 +167,10 @@ class RemoveZmovingMedian(OWWidget):
         _logger.info("done normalizing")
         self._computing = False

-        if self._thread.data_normalized is not None and self._thread.backgrounds is not None:
+        if self._thread.data_processed is not None and self._thread.backgrounds is not None:
-            self.Outputs.shared_data.send(
-                (self._thread.data_normalized, self.multiprocessing_array)
-            )
+            self.Outputs.data_shared.send((self._thread.data_processed, self.multiprocessing_array))
+            self.Outputs.data.send(self._thread.data_processed)
             self.Outputs.backgrounds.send(self._thread.backgrounds)
             self.information("done")
......
@@ -242,7 +242,7 @@ class RoiSelectionWidgetOW(OWWidget):
     def _sendSignal(self, roi_origin=[], roi_size=[]):
         """Emits the signal with the new data."""
         self.close()
-        self.roi_origin = list(reversed(roi_origin))  # they come as Y X
-        self.roi_size = list(reversed(roi_size))
-        self.Outputs.roi_origin.send(self.roi_origin)
-        self.Outputs.roi_size.send(self.roi_size)
+        self.roi_origin = roi_origin
+        self.roi_size = roi_size
+        self.Outputs.roi_origin.send(list(reversed(roi_origin)))  # they come as Y X
+        self.Outputs.roi_size.send(list(reversed(roi_size)))
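
The change keeps the stored state in the order the ROI widget reports it, (Y, X), and reverses only at send time, so the coordinates can never be double-flipped. Illustrative values:

roi_origin = [120, 45]             # (Y, X) as it comes from the ROI selection
roi_size = [64, 32]                # (height, width)
print(list(reversed(roi_origin)))  # [45, 120] -> (X, Y)
print(list(reversed(roi_size)))    # [32, 64]  -> (width, height)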
@@ -40,8 +40,8 @@ class AllocateSharedMemory(OWWidget):
         data = Input("data", ndarray)

     class Outputs:
-        data_shared = Output("data_shared", ndarray)
-        multiprocessing_array = Output("multiprocessing_array", ndarray)
+        data_shared = Output("data_shared", tuple)  # todo document this properly
+        data = Output("data", ndarray)  # todo document this properly

     def __init__(self):
         super().__init__()
@@ -68,6 +68,8 @@ class AllocateSharedMemory(OWWidget):
         if data is None:
             self.run_button.setDisabled(True)
+            self.Outputs.data_shared.send(None)
+            self.Outputs.data.send(None)
             self.information()
         else:
@@ -111,8 +113,8 @@
             self.data_shared = self._thread.data_shared
             self.multiprocessing_array = self._thread.multiprocessing_array
-            self.Outputs.data_shared.send(self.data_shared)
-            self.Outputs.multiprocessing_array.send(self.multiprocessing_array)
+            self.Outputs.data_shared.send((self.data_shared, self.multiprocessing_array))
+            self.Outputs.data.send(self.data_shared)
         else:
             self.error("something went wrong and the thread did not return the results properly")
......
@@ -12,13 +12,15 @@ from typing import Iterable, List, Optional, Tuple
 import _ctypes
 import numpy as np
 from numpy import ndarray
-from silx.io import get_data
-from silx.io import open as silx_open
-from silx.io.url import DataUrl

 # this is copied from silx.io.dictdump (line 266 in version 0.16.0 dev serial 0)
 # PR asking to put this in a constant in the module itself
 # https://github.com/silx-kit/silx/pull/3460
+from pydct import log
+from silx.io import get_data
+from silx.io import open as silx_open
+from silx.io.url import DataUrl

 UPDATE_MODE_VALID_EXISTING_VALUES = ("add", "replace", "modify")
@@ -27,9 +29,9 @@ update_mode: same values from silx.io.dictdump.dicttoh5 (files are saved in file
     defaults to 'add'

     Behavior:
-        - 'modify' will overwrite existing _data, but not the whole file nor the entire group, only the affected stuff
+        - 'modify' will overwrite existing data, but not the whole file nor the entire group, only the affected stuff
         - 'add' will NOT overwrite, only create new stuff without modifying other links in the group;
-        - 'replace' will replace the whole _data tree (group pointed by the link) in the h5, but not the entire file
+        - 'replace' will replace the whole data tree (group pointed by the link) in the h5, but not the entire file

     Reference: http://www.silx.org/doc/silx/latest/modules/io/dictdump.html#silx.io.dictdump.dicttoh5
 """
@@ -119,7 +121,7 @@ def catch_silx_io_exceptions(get_data_func, logger_=None):
     @functools.wraps(get_data_func)
     def wrapper(data_url: DataUrl):
-        logger_.debug("loading _data")
+        logger_.debug("loading data")

         try:
             data = get_data_func(data_url)
@@ -147,12 +149,12 @@ def catch_silx_io_exceptions(get_data_func, logger_=None):
             logger_.exception(ex)
             msg = ex.args[0]

-            # specify the cause when the _data path is missing
+            # specify the cause when the data path is missing
             if msg == "Argument 'path' must not be None":
-                raise ValueError(f"{data_url.path()=} missing _data path (h5 internal link after '::')") from ex
+                raise ValueError(f"{data_url.path()=} missing data path (h5 internal link after '::')") from ex

             if msg == "expected bytes, NoneType found":
-                raise ValueError(f"{data_url.path()=} missing _data path (h5 internal link after '::')") from ex
+                raise ValueError(f"{data_url.path()=} missing data path (h5 internal link after '::')") from ex

             raise
@@ -165,7 +167,7 @@ def catch_silx_io_exceptions(get_data_func, logger_=None):
             if "Data path from URL" in msg and "not found" in msg:
                 raise BadDataUrlError(f"{data_url.data_path()=} not in {data_url.file_path()=}") from ex

-        logger_.debug(f"_data successfully loaded with `{get_data_func.__name__}`")
+        logger_.debug(f"data successfully loaded with `{get_data_func.__name__}`")
         return data
@@ -184,6 +186,7 @@ class namespace2kwargs:
     def __call__(self, main_func):
         @functools.wraps(main_func)
         def wrapper(args):
+            """args is a namespace from argparse"""
             kwargs = {k: v for k, v in vars(args).items() if k in self.kwargs_keys}
             return main_func(**kwargs)
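
The added docstring points at the decorator's whole trick: filter the argparse namespace down to the keys the wrapped function accepts, then unpack. A self-contained equivalent, with argument names borrowed from this diff:

import argparse


def main(input_url=None, nprocs=None):
    print(input_url, nprocs)


parser = argparse.ArgumentParser()
parser.add_argument("--in", dest="input_url")
parser.add_argument("--nprocs", type=int)
args = parser.parse_args(["--in", "silx:/tmp/my.h5::/data", "--nprocs", "2"])

# the wrapper's whole job: keep only the keys main() accepts, then unpack
kwargs = {k: v for k, v in vars(args).items() if k in ("input_url", "nprocs")}
main(**kwargs)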
@@ -191,9 +194,9 @@
 def data2shared(data: ndarray, multiprocessing_rawarray: Optional[RawArray]) -> Tuple[ndarray, RawArray]:
-    """Check if the _data is already in a shared memor space or return such if not.
+    """Check if the data is already in a shared memory space or return such if not.

-    This is to avoid recopying the _data to a cross-process shared space, a very long operation.
+    This is to avoid recopying the data to a cross-process shared space, a very long operation.
     """

     if data.ndim != 3:
@@ -228,6 +231,8 @@ def data2shared(data: ndarray, multiprocessing_rawarray: Optional[RawArray]) ->
     else:
         mp_data_shared = multiprocessing_rawarray

+    print(f"in data2shared {id(data)=} {id(mp_data_shared)=}")
+
     return data, mp_data_shared
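
For reference, a minimal sketch of the shared-memory pattern `data2shared` builds on (all names here are illustrative, not the module's API): a `multiprocessing.RawArray` backs a numpy view, so the expensive copy into cross-process memory happens at most once.

import ctypes
import multiprocessing as mp

import numpy as np

data = np.random.rand(10, 4, 4).astype(np.float32)  # (z, x, y) volume
raw = mp.RawArray(ctypes.c_float, data.size)         # shared, lock-free buffer
shared = np.frombuffer(raw, dtype=np.float32).reshape(data.shape)
shared[:] = data  # one copy in; workers initialized with `raw` then mutate in place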
......
@@ -88,7 +88,7 @@ def main(
     Do the entire preprocessing:
         - todo list steps

-    !!IMPORTANT!! The _data volume's axes order is assumed to be (z, x, y).
+    !!IMPORTANT!! The data volume's axes order is assumed to be (z, x, y).

     Relevant functions called:
@@ -107,24 +107,24 @@ def main(
     https://stackoverflow.com/questions/53751050/python-multiprocessing-understanding-logic-behind-chunksize

-    IO: reads _data with `silx.io.get_data` and writes with `silx.io.dictdump.dicttoh5`
+    IO: reads data with `silx.io.get_data` and writes with `silx.io.dictdump.dicttoh5`

     Args:

-        input_url: url (scheme + path + link) to READ the INPUT _data
-            file and its internal link with the _data
-            ex: silx:/_data/id11/my.h5::/2.1/measurement/marana
+        input_url: url (scheme + path + link) to READ the INPUT data
+            file and its internal link with the data
+            ex: silx:/data/id11/my.h5::/2.1/measurement/marana
             obs1: 'silx:' indicates the h5 scheme
             obs2: silx (the library) also supports 'fabio', but that is not supported here

         dark_url: like `url`, but for the dark image

-        output_url: where the modified _data is dumped to (an h5 group)
+        output_url: where the modified data is dumped to (an h5 group)
             url (scheme + path + link) to WRITE the OUTPUT (see `url`)
             the outputs are dumped from a dict into an h5 group (see internal class Outputs)

         normalization: method used to compensate beam oscillations

-        normalization_numerator: an arbitrary value used as numerator when normalizing the _data
+        normalization_numerator: an arbitrary value used as numerator when normalizing the data

         margin_bounding_box: margin region used to compute the mean when using normalization method ~margin mean~
             upper left and bottom right coordinates given as (x, y)
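
A small sketch of how such a URL decomposes with silx's `DataUrl` (the path is illustrative):

from silx.io.url import DataUrl

url = DataUrl("silx:/data/id11/my.h5::/2.1/measurement/marana")
print(url.scheme())     # 'silx'
print(url.file_path())  # '/data/id11/my.h5'
print(url.data_path())  # '/2.1/measurement/marana'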
@@ -189,19 +189,21 @@ def main(
         raise ValueError(f"{median_window=} must be positive and at most {nz=}")

     if (darknx, darkny) != (nx, ny):
-        raise ValueError(f"incompatible dark/_data shapes on XY {dark_shape=} {data_shape=}")
+        raise ValueError(f"incompatible dark/data shapes on XY {dark_shape=} {data_shape=}")

-    # ul = upper left, br = bottom right
-    ((ulx, uly), (brx, bry)) = margin_bounding_box
+    if margin_bounding_box is not None:
+        # ul = upper left, br = bottom right
+        ((ulx, uly), (brx, bry)) = margin_bounding_box

-    for coord, val, upper_limit in [
-        ("margin upper left x", ulx, nx),
-        ("margin bottom right x", brx, nx),
-        ("margin upper left y", uly, ny),
-        ("margin bottom right y", bry, ny),
-    ]:
-        if not (0 <= val < upper_limit):
-            raise ValueError(f"{coord} incompatible {val=} {upper_limit=}")
+        for coord, val, upper_limit in [
+            ("margin upper left x", ulx, nx),
+            ("margin bottom right x", brx, nx),
+            ("margin upper left y", uly, ny),
+            ("margin bottom right y", bry, ny),
+        ]:
+            if not (0 <= val < upper_limit):
+                raise ValueError(f"{coord} incompatible {val=} {upper_limit=}")

     kernelx, kernely = filter_dimensions
@@ -220,17 +222,17 @@ def main(
     if update_mode not in common.UPDATE_MODE_VALID_EXISTING_VALUES:
         raise ValueError(f"{update_mode=} not valid; pick one from {common.UPDATE_MODE_VALID_EXISTING_VALUES}")

-    if nprocs < 1:
+    if nprocs is not None and nprocs < 1:
         raise ValueError(f"{nprocs=} must be >= 1")

     logger.debug("the args look good to go")

-    # ============================ _data ============================
+    # ============================ data ============================

     logger.info("loading dark")
     dark = common.catch_silx_io_exceptions(get_data, logger)(dark_url)

     logger.debug("validating dark")
-    dark = common.validate_3d_data(dark)
+    common.validate_3d_data(dark)

     if dark.dtype != np.float32:
         logger.warning(f"converting {dark.dtype=} to {np.float32.__name__}")
@@ -238,17 +240,17 @@ def main(
     logger.debug("the dark image looks fine")

-    logger.info("loading _data")
+    logger.info("loading data")
     data = common.catch_silx_io_exceptions(get_data, logger)(input_url)

-    logger.debug("validating _data")
-    data = common.validate_3d_data(data)
+    logger.debug("validating data")
+    common.validate_3d_data(data)

     if data.dtype != np.float32:
         logger.warning(f"converting {data.dtype=} to {np.float32.__name__}")
         data = data.astype(np.float32)

-    logger.debug("the _data looks fine")
+    logger.debug("the data looks fine")

     # ============================ processing ============================
@@ -257,7 +259,7 @@ def main(
     logger.debug("squeezing the dark volume on the z-axis")
     dark = np.squeeze(np.mean(dark, axis=0))  # todo verify if this is median or mean

-    logger.debug("subtract dark from _data")
+    logger.debug("subtract dark from data")
     data = data - dark

     if normalization is None:
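
The subtraction relies on plain numpy broadcasting: the squeezed dark frame is applied to every z-slice of the volume. An illustration with made-up shapes:

import numpy as np

data = np.random.rand(100, 4, 5).astype(np.float32)           # (nz, nx, ny)
dark = np.squeeze(np.mean(np.random.rand(10, 4, 5), axis=0))  # (nx, ny)
corrected = data - dark  # dark broadcasts across the z-axis
print(corrected.shape)   # (100, 4, 5)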
@@ -294,7 +296,7 @@ def main(
         data=data,
         backgrounds=background_medians if save_backgrounds else None,
         parameters=dict(
-            normalization=normalization.value,
+            normalization=None if normalization is None else normalization.value,
             normalization_numerator=normalization_numerator,
             margin_bounding_box=margin_bounding_box,
             median_validity=median_validity,
@@ -366,16 +368,16 @@ preprocess_parser.add_argument(
     "--in",
     type=DataUrl,
     metavar="silx:input.h5::/3d_data",
-    help="url (scheme + path + link) to the INPUT _data",
+    help="url (scheme + path + link) to the INPUT data",
     required=True,
-    dest="url",
+    dest="input_url",
 )

 preprocess_parser.add_argument(
     "--dark",
     type=DataUrl,
     metavar="silx:darkend.h5::/3d_data",
-    help="url (scheme + path + link) to the DARK _data (no beam)",
+    help="url (scheme + path + link) to the DARK data (no beam)",
     required=True,
     dest="dark_url",
 )
@@ -423,6 +425,7 @@ preprocess_parser.add_argument(
     type=bounding_box_2d,
     metavar="MBB",
     help="upper left (UL) and bottom right (BR) corners (x, y) of a 2D bounding box of the margin: ULx,ULy,BRx,BRy",
+    dest="margin_bounding_box",  # must be present so the kwarg shows up in the namespace
 )

 preprocess_parser.add_argument(
@@ -462,6 +465,7 @@ preprocess_parser.add_argument(
     metavar="FD",
     default=(3, 3),
     type=kernel_shape_2d,
+    dest="filter_dimensions",  # must be present so the kwarg shows up in the namespace
 )

 preprocess_parser.add_argument(
@@ -497,19 +501,20 @@ class ExamplesAction(ExampleCallsActionAbstract):
     @property
     def examples(self) -> List[str]:
         filename = Path(__file__).name
-        in_url_str = (
-            "silx:/_data/id11/3dxrd/blc12852/id11/bmg_l1/bmg_l1_bmg_dct2/bmg_l1_bmg_dct2.h5::/10.1/measurement/marana "
-        )
-        out_url_str = "silx:/tmp/output.h5::/_data"
+        small_in_url_str = "silx:/data/id11/3dxrd/blc12852/id11/bmg_l1/bmg_l1_bmg_dct2/scan0002/marana_0000.h5::/entry_0000/ESRF-ID11/marana/data"
+        in_url_str = "silx:/data/id11/3dxrd/blc12852/id11/bmg_l1/bmg_l1_bmg_dct2/bmg_l1_bmg_dct2.h5::/2.1/measurement/marana "
+        out_url_str = "silx:/tmp/output.h5::/data"
         dark_url_str = (
-            "silx:/_data/id11/3dxrd/blc12852/id11/bmg_l1/bmg_l1_bmg_dct2/bmg_l1_bmg_dct2.h5::/10.1/measurement/marana"
+            "silx:/data/id11/3dxrd/blc12852/id11/bmg_l1/bmg_l1_bmg_dct2/bmg_l1_bmg_dct2.h5::/10.1/measurement/marana"
         )
         return [
-            f"{filename} --help",
-            f"{filename} --docstrings",
-            f"{filename} -vv --in {in_url_str} --dark {dark_url_str} --out {out_url_str}",
-            f"{filename} -vv --in {in_url_str} --dark {dark_url_str} --out {out_url_str} "
+            f"python {filename} {COMMAND_PREPROCESS} --help",
+            f"python {filename} {COMMAND_PREPROCESS} --docstrings",
+            f"python {filename} {COMMAND_PREPROCESS} -vv --in {in_url_str} --dark {dark_url_str} --out {out_url_str}",
+            f"python {filename} {COMMAND_PREPROCESS} -vv --in {in_url_str} --dark {dark_url_str} --out {out_url_str} "
             "--normalization margin_mean --margin_bounding_box 100,100,200,200 --normalization_numerator 100",
+            f"python {filename} {COMMAND_PREPROCESS} -vv --in {small_in_url_str} --dark {dark_url_str} --out {out_url_str} "
+            "--moving_median_validity 5 --moving_median_window 50 --save_backgrounds --update_mode replace",
         ]
......
@@ -106,7 +106,9 @@ def _worker_remove_median(zwindow_idx, zvalidity_start, zvalidity_stop):
     global data_, medians_
-    data_[zvalidity_start:zvalidity_stop, :, :] = data_[zvalidity_start:zvalidity_stop, :, :] - medians_[zwindow_idx, :, :]
+    data_[zvalidity_start:zvalidity_stop, :, :] = (
+        data_[zvalidity_start:zvalidity_stop, :, :] - medians_[zwindow_idx : (zwindow_idx + 1), :, :]
+    )


 def remove_moving_medians(
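
Both indexings broadcast the same way over the validity slab; the slice just keeps the z-axis explicit as shape (1, nx, ny) instead of (nx, ny). Made-up shapes:

import numpy as np

data_ = np.zeros((6, 4, 5), dtype=np.float32)
medians_ = np.ones((2, 4, 5), dtype=np.float32)

print(medians_[0, :, :].shape)    # (4, 5)   -> z-axis broadcast implicitly
print(medians_[0:1, :, :].shape)  # (1, 4, 5) -> z-axis kept explicit
data_[0:3, :, :] -= medians_[0:1, :, :]  # same median subtracted from each z-slice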
@@ -184,6 +186,8 @@ def remove_moving_medians(
     medians_shared = np.frombuffer(mp_medians_shared, dtype=data.dtype).reshape((nmedians, nx, ny))

     logger.debug("computing and removing the z-wise moving medians")
+
+    # entering this block initializes the children processes (see initializer)
     with contextlib.closing(
         mp.Pool(
@@ -204,14 +208,28 @@
             # from a slice could make another's computation go wrong
             # btw, it's important to do the operation in-place because the memory allocation
             # takes a considerable time
+            logger.debug("computing the medians")
             pool.starmap_async(
                 _worker_compute_median,
                 ((idx, window_bounds[0, idx], window_bounds[1, idx]) for idx in range(nmedians)),
             )
-        logger.debug("subtracting medians")
+    pool.join()
+
+    # todo test this (values should match a theoretical example)
+    with contextlib.closing(
+        mp.Pool(
+            processes=nprocs,
+            initializer=_worker_init,
+            initargs=(
+                mp_data_shared,
+                mp_medians_shared,
+                data.shape,
+                data.dtype,
+                nmedians,
+            ),
+        )
+    ) as pool:
         pool.starmap_async(
             _worker_remove_median,
             ((idx, validity_bounds[0, idx], validity_bounds[1, idx]) for idx in range(nmedians)),
......
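
The split into two pools makes the ordering explicit: the first pool is closed and joined before the subtraction workers start, otherwise the async median computation could still be running when the medians are consumed. A minimal sketch of that lifecycle (the worker is illustrative):

import contextlib
import multiprocessing as mp


def _square(i, offset):
    return i * i + offset


if __name__ == "__main__":
    with contextlib.closing(mp.Pool(processes=2)) as pool:
        result = pool.starmap_async(_square, ((i, 10) for i in range(8)))
    # contextlib.closing() called pool.close() on exit (no new tasks accepted),
    # so join() is legal here and blocks until the async starmap has finished
    pool.join()
    print(result.get())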