From 1d2881dd415967275e958169c9ea3693b4fe0fa3 Mon Sep 17 00:00:00 2001
From: payno <henri.payno@esrf.fr>
Date: Thu, 8 Feb 2024 08:29:43 +0100
Subject: [PATCH] Merge branch 'fix_8' into '1.2'

Fix NXtransformations equality test

See merge request tomotools/nxtomo!24

(cherry picked from commit 7b8cc70160e3ad8abe69cb031c4ab8866036301d)

8d55dec0 NXtransformations: fix comparaison: avoid taking into account gravity special case.
---
 nxtomo/nxobject/nxtransformations.py           |  8 +++++++-
 nxtomo/nxobject/test/test_nxtransformations.py | 10 +++++++++-
 nxtomo/utils/transformation.py                 |  2 ++
 3 files changed, 18 insertions(+), 2 deletions(-)

diff --git a/nxtomo/nxobject/nxtransformations.py b/nxtomo/nxobject/nxtransformations.py
index cb04d47..999eb4b 100644
--- a/nxtomo/nxobject/nxtransformations.py
+++ b/nxtomo/nxobject/nxtransformations.py
@@ -243,7 +243,13 @@ class NXtransformations(NXobject):
         if not isinstance(__value, NXtransformations):
             return False
         else:
-            return self.transformations == __value.transformations
+            # to check equality we filter gravity as it can be provided at the end and as the reference
+            def is_gravity(transformation):
+                return transformation == GravityTransformation()
+
+            return list(filter(is_gravity, self.transformations)) == list(
+                filter(is_gravity, __value.transformations)
+            )
 
     @staticmethod
     def is_a_valid_group(group: h5py.Group) -> bool:
diff --git a/nxtomo/nxobject/test/test_nxtransformations.py b/nxtomo/nxobject/test/test_nxtransformations.py
index be17565..9f85d13 100644
--- a/nxtomo/nxobject/test/test_nxtransformations.py
+++ b/nxtomo/nxobject/test/test_nxtransformations.py
@@ -5,6 +5,7 @@ from nxtomo.nxobject.nxtransformations import NXtransformations
 from nxtomo.utils.transformation import (
     Transformation,
     TransformationAxis,
+    GravityTransformation,
 )
 
 
@@ -136,6 +137,13 @@ def test_nx_transforamtions(tmp_path):
     assert len(loaded_transformations.transformations) == 2
     assert loaded_transformations == nx_transformations_2
 
+    # check that Gravity will not affect the equality
+    nx_transformations_2.add_transformation(GravityTransformation())
+    assert loaded_transformations == nx_transformations_2
+
+    loaded_transformations.add_transformation(GravityTransformation())
+    assert loaded_transformations == nx_transformations_2
+
     output_file_path_2 = str(tmp_path / "test_nxtransformations.nx")
     nx_transformations_2.save(output_file_path_2, "/entry/toto/transformations")
 
@@ -143,5 +151,5 @@ def test_nx_transforamtions(tmp_path):
         output_file_path_2, "/entry/toto/transformations", 1.3
     )
     assert isinstance(loaded_transformations, NXtransformations)
-    assert len(loaded_transformations.transformations) == 2
+    assert len(loaded_transformations.transformations) == 3
     assert loaded_transformations == nx_transformations_2
diff --git a/nxtomo/utils/transformation.py b/nxtomo/utils/transformation.py
index 2cb7de3..260d334 100644
--- a/nxtomo/utils/transformation.py
+++ b/nxtomo/utils/transformation.py
@@ -189,6 +189,8 @@ class Transformation:
                 return None, "degree"
             else:
                 return transformation_values % 360, "degree"
+        elif units == "m/s2":
+            return transformation_values, "m/s2"
         else:
             converted_values = (
                 transformation_values * MetricSystem.from_str(str(units)).value
-- 
GitLab