Commit 8f915e0a authored by Thomas Vincent's avatar Thomas Vincent

add almost_equal to FitResult for tests in Windows

parent 52d9383e
......@@ -251,6 +251,17 @@ class FitResult(object):
def __eq__(self, other):
"""Implement equality, useful for tests"""
return self.almost_equal(other, rtol=0., atol=0.)
def almost_equal(self, other, rtol=1e-5, atol=1e-8):
"""Implement almost equal comparison, useful for tests
:param FitResult other: The other FitResult to compare to
:param float rtol: The relative tolerance.
See :func:`numpy.allclose` for details.
:param float atol: The absolute tolerance.
See :func:`numpy.allclose` for details.
"""
return (
isinstance(other, FitResult) and
numpy.array_equal(self.sample_x, other.sample_x) and
......@@ -260,10 +271,18 @@ class FitResult(object):
other.qspace_dimension_values)]) and
(tuple(self.qspace_dimension_names) ==
tuple(other.qspace_dimension_names)) and
numpy.array_equal(self.roi_indices, self.roi_indices) and
numpy.array_equal(self.roi_indices, other.roi_indices) and
self.fit_mode == other.fit_mode and
self.background_mode == self.background_mode and
numpy.array_equal(self._fit_results, other._fit_results))
self.background_mode == other.background_mode and
(set(self.available_result_names) ==
set(other.available_result_names)) and
numpy.all([numpy.allclose(ary1[param],
ary2[param],
rtol=rtol,
atol=atol)
for param in self.available_result_names
for ary1, ary2 in zip(self._fit_results,
other._fit_results)]))
class PeakFitter(object):
......
......@@ -33,6 +33,7 @@ __date__ = "05/01/2016"
import os
import shutil
import sys
import tempfile
import unittest
......@@ -68,6 +69,17 @@ class TestPeakFitter(ParametricTestCase):
shutil.rmtree(self._tmpTestDir)
self._tmpTestDir = None
def _assertResultAlmostEqual(self, ref, result):
"""Compare FitResults are equal on Linux and almost equal on Windows
:param FitResult ref:
:param FitResult result:
"""
if sys.platform != 'win32':
self.assertTrue(ref.almost_equal(result))
else:
self.assertEqual(ref, result)
def test_gaussian(self):
"""Test gaussian fit"""
for fit_f, qspace_f in zip(self._GAUSSIAN_FILES, self._QSPACE_FILES):
......@@ -86,12 +98,13 @@ class TestPeakFitter(ParametricTestCase):
# Compare results
ref = FitResult.from_fit_h5(
test_resources.getfile('fit_2018_12/' + fit_f))
self.assertEqual(ref, fitter.results)
self._assertResultAlmostEqual(ref, fitter.results)
# Save as HDF5 and compare
fit_out = os.path.join(self._tmpTestDir, fit_f)
fitter.results.to_fit_h5(fit_out)
self.assertEqual(ref, FitResult.from_fit_h5(fit_out))
self._assertResultAlmostEqual(ref,
FitResult.from_fit_h5(fit_out))
def test_com(self):
"""Test Center-of-mass and Max reduction"""
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment