shift.py 3.76 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# coding: utf-8
# /*##########################################################################
#
# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# ###########################################################################*/


__authors__ = ["H. Payno"]
__license__ = "MIT"
__date__ = "15/10/2018"


import logging
from . import OverwritingOperation
from collections import namedtuple
from id06workflow.core import image
import numpy

# Use the dotted module name (not the file path) so this logger is part of the
# package logger hierarchy and inherits package-level logging configuration.
_logger = logging.getLogger(__name__)

# Record describing a shift: x/y translations, the entry it applies to and the
# grid of shifts it belongs to (presumably used by shift-search callers --
# TODO confirm the exact meaning of `entry` and `grid_shift`).
ShiftValue = namedtuple('ShiftValue', ['dx', 'dy', 'entry', 'grid_shift'])


class Shift(OverwritingOperation):
    """
    Apply a linear shift along a stack of images. Overwrite data.

    Image ``i`` of the stack is translated by ``(i * dx, i * dy)`` using
    :func:`id06workflow.core.image.shift_img`, so ``dx`` and ``dy`` are the
    translation between two consecutive images.

    :param experiment: experiment the operation is applied to (forwarded to
                       :class:`OverwritingOperation`, which presumably exposes
                       it as ``self.data`` -- TODO confirm)
    :param float dx: x translation between two consecutive images
    :param float dy: y translation between two consecutive images
    :param dz: z translation between two consecutive images. Only ``0`` is
               supported for now.
    :raises NotImplementedError: if ``dz`` is not 0
    """
    def __init__(self, experiment, dx=0.0, dy=0.0, dz=0):
        OverwritingOperation.__init__(self, experiment, name='shift')
        if dz != 0:
            raise NotImplementedError('z shift not taken into account yet')
        # Compare by value: `is` on an int literal tests object identity.
        assert self.data.ndim == 3, 'data is expected to be a 3D stack'
        self.dx = dx
        self.dy = dy
        self.dz = dz
        # Result of the last dry_run, committed by apply()
        self._cache_data = None

    @property
    def dx(self):
        """x translation between two consecutive images"""
        return self._dx

    @dx.setter
    def dx(self, dx):
        self._dx = dx

    @property
    def dy(self):
        """y translation between two consecutive images"""
        return self._dy

    @dy.setter
    def dy(self, dy):
        self._dy = dy

    @property
    def dz(self):
        """z translation between two consecutive images (only 0 supported)"""
        return self._dz

    # Must be `dz.setter`: `dx.setter` would build a property whose getter
    # returns `self._dx`.
    @dz.setter
    def dz(self, dz):
        self._dz = dz

    def compute(self):
        """Replace ``self.data`` by its shifted version and register the
        operation.

        :return: the shifted data
        :rtype: numpy.ndarray
        """
        self.data = self._compute(self.data)
        self.registerOperation()
        return self.data

    def dry_run(self, cache_data=None):
        """Compute the shift and keep the result in cache without reassigning
        ``self.data``. Use :meth:`apply` to commit it.

        :param cache_data: data to shift instead of ``self.data``. If None,
                           ``self.data`` is used.
        :return: the shifted data (also stored in the cache)
        :rtype: numpy.ndarray
        """
        if cache_data is None:
            self._cache_data = self.data[...]
        else:
            self._cache_data = cache_data
        self._cache_data = self._compute(self._cache_data)
        return self._cache_data

    def apply(self):
        """Commit the result of the last :meth:`dry_run` to ``self.data``,
        clear the cache and register the operation.

        :return: the new data
        :raises ValueError: if no dry run result is cached
        """
        if self._cache_data is None:
            raise ValueError('No data in cache')
        self.data = self._cache_data
        self.clear_cache()
        self.registerOperation()
        return self.data

    def clear_cache(self):
        """Drop the result of the last :meth:`dry_run`."""
        self._cache_data = None

    def _compute(self, data):
        """Return a new array where image ``i`` of ``data`` is shifted by
        ``(i * dx, i * dy)``.

        :param data: stack of images, iterated along its first axis
        :rtype: numpy.ndarray
        """
        n_img = data.shape[0]
        res = []
        # Iterate over the `data` argument, not `self.data`: dry_run may pass
        # cached data that differs from the operation's current data.
        for i_img, img in enumerate(data):
            # Float first so the percentage never truncates to 0 with
            # integer division.
            self.updateProgress(int(100.0 * i_img / n_img))
            res.append(image.shift_img(img,
                                       dx=self.dx * i_img,
                                       dy=self.dy * i_img))
        return numpy.asarray(res)

    def key(self):
        """Return a string identifying this operation and its parameters."""
        return ' '.join((self._name, 'dx:', str(self.dx), 'dy:', str(self.dy),
                         'dz:', str(self.dz)))