From 1d2881dd415967275e958169c9ea3693b4fe0fa3 Mon Sep 17 00:00:00 2001 From: payno <henri.payno@esrf.fr> Date: Thu, 8 Feb 2024 08:29:43 +0100 Subject: [PATCH] Merge branch 'fix_8' into '1.2' Fix NXtransformations equality test See merge request tomotools/nxtomo!24 (cherry picked from commit 7b8cc70160e3ad8abe69cb031c4ab8866036301d) 8d55dec0 NXtransformations: fix comparaison: avoid taking into account gravity special case. --- nxtomo/nxobject/nxtransformations.py | 8 +++++++- nxtomo/nxobject/test/test_nxtransformations.py | 10 +++++++++- nxtomo/utils/transformation.py | 2 ++ 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/nxtomo/nxobject/nxtransformations.py b/nxtomo/nxobject/nxtransformations.py index cb04d47..999eb4b 100644 --- a/nxtomo/nxobject/nxtransformations.py +++ b/nxtomo/nxobject/nxtransformations.py @@ -243,7 +243,13 @@ class NXtransformations(NXobject): if not isinstance(__value, NXtransformations): return False else: - return self.transformations == __value.transformations + # to check equality we filter gravity as it can be provided at the end and as the reference + def is_gravity(transformation): + return transformation == GravityTransformation() + + return list(filter(is_gravity, self.transformations)) == list( + filter(is_gravity, __value.transformations) + ) @staticmethod def is_a_valid_group(group: h5py.Group) -> bool: diff --git a/nxtomo/nxobject/test/test_nxtransformations.py b/nxtomo/nxobject/test/test_nxtransformations.py index be17565..9f85d13 100644 --- a/nxtomo/nxobject/test/test_nxtransformations.py +++ b/nxtomo/nxobject/test/test_nxtransformations.py @@ -5,6 +5,7 @@ from nxtomo.nxobject.nxtransformations import NXtransformations from nxtomo.utils.transformation import ( Transformation, TransformationAxis, + GravityTransformation, ) @@ -136,6 +137,13 @@ def test_nx_transforamtions(tmp_path): assert len(loaded_transformations.transformations) == 2 assert loaded_transformations == nx_transformations_2 + # check that Gravity will not affect the equality + nx_transformations_2.add_transformation(GravityTransformation()) + assert loaded_transformations == nx_transformations_2 + + loaded_transformations.add_transformation(GravityTransformation()) + assert loaded_transformations == nx_transformations_2 + output_file_path_2 = str(tmp_path / "test_nxtransformations.nx") nx_transformations_2.save(output_file_path_2, "/entry/toto/transformations") @@ -143,5 +151,5 @@ def test_nx_transforamtions(tmp_path): output_file_path_2, "/entry/toto/transformations", 1.3 ) assert isinstance(loaded_transformations, NXtransformations) - assert len(loaded_transformations.transformations) == 2 + assert len(loaded_transformations.transformations) == 3 assert loaded_transformations == nx_transformations_2 diff --git a/nxtomo/utils/transformation.py b/nxtomo/utils/transformation.py index 2cb7de3..260d334 100644 --- a/nxtomo/utils/transformation.py +++ b/nxtomo/utils/transformation.py @@ -189,6 +189,8 @@ class Transformation: return None, "degree" else: return transformation_values % 360, "degree" + elif units == "m/s2": + return transformation_values, "m/s2" else: converted_values = ( transformation_values * MetricSystem.from_str(str(units)).value -- GitLab