Skip to content
Snippets Groups Projects
Commit d82694c6 authored by payno's avatar payno
Browse files

Merge branch 'update_jp2k_volume' into '1.3'

JP2KVolume: add data cast and clipping

See merge request !155
parents b196ee60 a6e2ce7d
No related branches found
No related tags found
1 merge request!155JP2KVolume: add data cast and clipping
Pipeline #130200 passed
Change Log
==========
1.3.2: 2023/08/23
-----------------
* volume
    * jp2k: rescale data when casting it to uint16 (enabled by default)
1.3.0: 2023/08/01
-----------------
......
......@@ -41,6 +41,7 @@ from silx.io.url import DataUrl
from tomoscan.esrf.identifier.jp2kidentifier import JP2KVolumeIdentifier
from tomoscan.scanbase import TomoScanBase
from tomoscan.utils import docstring
from tomoscan.utils.volume import rescale_data
from .singleframebase import VolumeSingleFrameBase
......@@ -77,6 +78,9 @@ class JP2KVolume(VolumeSingleFrameBase):
This defines a quality metric for lossy compression.
The number "0" stands for lossless compression.
:param Optional[int] n_threads: number of thread to use for writing. If None will try to get as much as possible
:param Optional[tuple] clip_values: optional tuple of two float (min, max) to clamp volume value
:param bool rescale_data: rescale data before dumping each frame. Expected to be True when dump a new volume and False
when save volume cast for example (and when histogram is know...)
:warning: each file saved under {volume_basename}_{index_zfill6}.jp2k is considered to be a slice of the volume.
"""
......@@ -101,6 +105,8 @@ class JP2KVolume(VolumeSingleFrameBase):
cratios: Optional[list] = None,
psnr: Optional[list] = None,
n_threads: Optional[int] = None,
clip_values: Optional[tuple] = None,
rescale_data: bool = True,
) -> None:
if folder is not None:
url = DataUrl(
......@@ -132,7 +138,13 @@ class JP2KVolume(VolumeSingleFrameBase):
)
self._cratios = cratios
self._psnr = psnr
self._clip_values = None
self._cast_already_logged = False
"""bool used to avoid logging potential data cast for each frame when save a volume"""
self._rescale_data = rescale_data
"""should we rescale data before dumping the frame"""
self.setup_multithread_encoding(n_threads=n_threads)
self.clip_values = clip_values # execute test about the type...
@property
def cratios(self) -> Optional[list]:
......@@ -150,11 +162,69 @@ class JP2KVolume(VolumeSingleFrameBase):
def psnr(self, psnr: Optional[list]):
self._psnr = psnr
@property
def rescale_data(self) -> bool:
    # True when each frame is linearly rescaled before being dumped (see save_frame)
    return self._rescale_data
@rescale_data.setter
def rescale_data(self, rescale: bool) -> None:
    """
    Enable / disable linear rescaling of each frame before it is dumped.

    :param bool rescale: True to rescale frames to the full output range
    :raises TypeError: when rescale is not a bool
    """
    if not isinstance(rescale, bool):
        # original raised a bare TypeError; give the caller a usable message
        raise TypeError(
            f"rescale is expected to be a bool. Got {type(rescale)}"
        )
    self._rescale_data = rescale
@property
def clip_values(self) -> Optional[tuple]:
    """
    :return: optional (min, max) values used to clamp each frame before it is
        saved. None when no clipping is requested.
    :rtype: Optional[tuple]
    """
    return self._clip_values
@clip_values.setter
def clip_values(self, values: Optional[tuple]) -> None:
    """
    Set the optional (min, max) clamp range applied to each frame.

    :param Optional[tuple] values: None to disable clipping, else a two-element
        (min, max) sequence with min <= max. Stored as a tuple so the getter
        always returns the advertised type.
    :raises TypeError: when values is neither None nor a tuple/list
    :raises ValueError: when values does not hold exactly two ordered floats
    """
    if values is None:
        self._clip_values = None
        return
    if not isinstance(values, (tuple, list)):
        # original raised a bare TypeError; give the caller a usable message
        raise TypeError(
            f"values is expected to be None or a tuple of two floats. Got {type(values)}"
        )
    if len(values) != 2:
        raise ValueError("clip values are expected to be two floats")
    # keep the original `not >=` form so a NaN bound is still rejected
    if not values[1] >= values[0]:
        raise ValueError("clip values are expected to be provided as (min, max)")
    # normalize to tuple: the getter documents Optional[tuple]
    self._clip_values = tuple(values)
@docstring(VolumeSingleFrameBase)
def save_data(self, url: Optional[DataUrl] = None) -> None:
    # reset the flag so a potential dtype-cast message is logged (at most)
    # once per volume save instead of once per frame
    self._cast_already_logged = False
    super().save_data(url=url)
@docstring(VolumeSingleFrameBase)
def save_frame(self, frame, file_name, scheme):
if not has_glymur:
raise RuntimeError(_MISSING_GLYMUR_MSG)
if self.clip_values is not None:
frame = numpy.clip(frame, self.clip_values[0], self.clip_values[1])
if self.rescale_data:
if frame.dtype in (numpy.uint8, numpy.uint16):
max_uint = numpy.iinfo(frame.dtype).max
else:
max_uint = numpy.iinfo(numpy.uint16).max
frame = rescale_data(
data=frame,
new_min=0,
new_max=max_uint,
data_min=self.clip_values[0] if self.clip_values is not None else None,
data_max=self.clip_values[1] if self.clip_values is not None else None,
)
if not isinstance(frame.dtype, (numpy.uint8, numpy.uint16)):
if self._cast_already_logged:
self._cast_already_logged = True
_logger.info(
f"{self.get_identifier().to_str()} get {frame.dtype}. Cast it as {numpy.uint16}"
)
frame = frame.astype(numpy.uint16)
if scheme == "glymur":
glymur.Jp2k(file_name, data=frame, psnr=self.psnr, cratios=self.cratios)
else:
......
"""specific test for the jp2kvolume. the large part is in test_single_frame_volume as most of the processing is common"""
import os
import numpy
from tomoscan.esrf.volume.jp2kvolume import JP2KVolume
from tomoscan.esrf.volume.mock import create_volume
# module-level fixture volume shared by the tests below
_data = create_volume(
    frame_dims=(100, 100), z_size=11
)  # z_size need to be at least 10 to check loading from file name works
# shift every frame by +1 — presumably so the volume holds no zero values;
# TODO confirm against create_volume's output range
for i in range(len(_data)):
    _data[i] += 1
# cast to uint16: the dtype JP2KVolume writes natively
_data = _data.astype(numpy.uint16)
def test_jp2kvolume_rescale(tmp_path):
    """
    Test that rescale is correctly applied by default by the JP2KVolume
    """
    # os.makedirs creates intermediate directories, so one call builds
    # both the acquisition folder and the volume folder below it
    save_dir = str(tmp_path / "acquisition" / "volume")
    os.makedirs(save_dir)

    vol = JP2KVolume(folder=save_dir, data=_data, metadata={})
    vol.save()
    vol.clear_cache()
    vol.load()

    # default rescaling stretches the data to the full uint16 range
    uint16_max = numpy.iinfo(numpy.uint16).max
    assert vol.data.min() == 0
    assert vol.data.max() == uint16_max
......@@ -143,6 +143,8 @@ def test_create_volume_from_folder(tmp_path, volume_constructor):
volume.save()
volume.overwrite = True
if isinstance(volume, JP2KVolume):
volume.rescale_data = False
volume.save()
# check load data and metadata
......@@ -202,6 +204,8 @@ def test_data_file_saver_generator(tmp_path, volume_constructor):
volume_dir = str(tmp_path / "volume")
os.makedirs(volume_dir)
volume = volume_constructor(folder=volume_dir)
if isinstance(volume, JP2KVolume):
volume.rescale_data = False
for slice_, slice_saver in zip(
_data,
volume.data_file_saver_generator(
......@@ -232,6 +236,11 @@ def test_several_writer(tmp_path, volume_constructor):
metadata=_metadata,
)
volume_2 = volume_constructor(folder=volume_dir, data=_data[5:], start_index=5)
if isinstance(volume_1, JP2KVolume):
volume_1.rescale_data = (
False # keep coherence between all the volumes. Simplify test
)
volume_2.rescale_data = False
volume_1.save()
volume_2.save()
......@@ -258,6 +267,10 @@ def test_volume_identifier(tmp_path, volume_constructor):
data=_data,
metadata=_metadata,
)
if isinstance(volume, JP2KVolume):
volume.rescale_data = (
False # keep coherence between all the volumes. Simplify test
)
volume.save()
identifier = volume.get_identifier()
assert isinstance(identifier, VolumeIdentifier)
......@@ -324,6 +337,10 @@ def test_volume_with_prefix(tmp_path, volume_constructor):
metadata=_metadata,
volume_basename=file_prefix,
)
if isinstance(volume_1, JP2KVolume):
volume_1.rescale_data = (
False # keep coherence between all the volumes. Simplify test
)
volume_1.save()
full_volume = volume_constructor(folder=volume_dir, volume_basename=file_prefix)
......
......@@ -101,6 +101,8 @@ def test_concatenate_volume(tmp_path, volume_class, axis):
else:
vol_params.update({"folder": os.path.join(raw_data_dir, f"volume_{i_vol}")})
volume = volume_class(**vol_params)
if isinstance(volume, JP2KVolume):
volume.rescale_data = False # simplify test
volume.save()
volumes.append(volume)
volume.data = None
......@@ -117,6 +119,8 @@ def test_concatenate_volume(tmp_path, volume_class, axis):
final_volume = volume_class(
folder=os.path.join(output_dir, "final_vol"),
)
if isinstance(final_volume, JP2KVolume):
final_volume.rescale_data = False
concatenate(output_volume=final_volume, volumes=volumes, axis=axis)
if axis == 0:
......
......@@ -47,6 +47,13 @@ def concatenate(output_volume: VolumeBase, volumes: tuple, axis: int) -> None:
if len(invalids) > 0:
raise ValueError(f"Several non-volumes found. ({invalids})")
from tomoscan.esrf.volume.jp2kvolume import JP2KVolume # avoid cyclic import
if isinstance(output_volume, JP2KVolume) and output_volume.rescale_data is True:
_logger.warning(
"concatenation will rescale data frame. If you want to avoid this please set output volume 'rescale_data' to False"
)
# 1. compute final shape
def get_volume_shape():
if axis == 0:
......@@ -211,3 +218,11 @@ def update_metadata(ddict_1: dict, ddict_2: dict) -> dict:
else:
ddict_1[key] = value
return ddict_1
def rescale_data(data, new_min, new_max, data_min=None, data_max=None):
    """
    Linearly rescale `data` from [data_min, data_max] to [new_min, new_max].

    :param data: array-like values to rescale
    :param new_min: lower bound of the target range
    :param new_max: upper bound of the target range
    :param data_min: source minimum. Computed from `data` when None
    :param data_max: source maximum. Computed from `data` when None
    :return: rescaled values as a floating-point numpy array
    """
    data = numpy.asarray(data)
    if data_min is None:
        data_min = numpy.min(data)
    if data_max is None:
        data_max = numpy.max(data)
    if data_max == data_min:
        # degenerate (constant) input: the original formula divided by zero
        # and silently produced NaN / inf; map everything to new_min instead
        return numpy.full(data.shape, new_min, dtype=numpy.float64)
    return (new_max - new_min) / (data_max - data_min) * (data - data_min) + new_min
......@@ -77,7 +77,7 @@ RELEASE_LEVEL_VALUE = {
# version 1.3.2 (matches the Change Log entry for 2023/08/23)
MAJOR = 1
MINOR = 3
MICRO = 2  # dead duplicate assignment `MICRO = 1` removed
RELEV = "final"  # <16
SERIAL = 4  # <16
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment