Commit f80f9fce authored by Wout De Nolf's avatar Wout De Nolf
Browse files

greenlet_utils: refactor to make things more explicit

parent 28de930f
Pipeline #49015 passed with stages
in 114 minutes and 25 seconds
......@@ -7,43 +7,59 @@ from gevent import monkey
from gevent import greenlet, timeout
import gevent
MASKED_GREENLETS = dict()
GREENLET_MASK_STATE = dict()
class KillMask:
"""All exceptions which are the result of a `kill(SomeException)`
call on the greenlet in which the KillMask context is entered,
will be delayed until context exit, except for `gevent.Timeout`.
Upon exit, only the last capture exception is re-raised.
Optionally we can set a limit to the number of `kill` calls
will be delayed.
Warning: this does not delay interrupts.
"""
def __init__(self, masked_kill_nb=-1):
"""
masked_kill_nb: nb of masked kill
< 0 mean all kills are masked.
if > 0, at each kill attempt the counter decrements until 0, then the greenlet can be killed
masked_kill_nb: number of masked `kill` calls that will be delayed
> 0 this ammount of kills will be delayed
== 0 no kill is delayed
< 0 unlimited
"""
self.__greenlet = gevent.getcurrent()
self.__kill_counter = masked_kill_nb
self.__masked_kill_nb = masked_kill_nb
self.__allowed_kills = masked_kill_nb
self.__last_captured_exception = None
def __enter__(self):
self.__exception = None
MASKED_GREENLETS.setdefault(self.__greenlet, set()).add(self)
self.__allowed_kills = self.__masked_kill_nb
self.__last_captured_exception = None
GREENLET_MASK_STATE.setdefault(self.__greenlet, set()).add(self)
def __exit__(self, exc_type, value, traceback):
MASKED_GREENLETS[self.__greenlet].remove(self)
if MASKED_GREENLETS[self.__greenlet]:
GREENLET_MASK_STATE[self.__greenlet].remove(self)
if GREENLET_MASK_STATE[self.__greenlet]:
return
MASKED_GREENLETS.pop(self.__greenlet)
if self.__exception is not None:
raise self.__exception
GREENLET_MASK_STATE.pop(self.__greenlet)
if self.__last_captured_exception is not None:
raise self.__last_captured_exception
@property
def exception(self):
return self.__exception
def last_captured_exception(self):
return self.__last_captured_exception
def set_throw(self, exception):
if self.__kill_counter:
self.__exception = exception
else: # reach 0
self.__exception = None
cnt = self.__kill_counter
self.__kill_counter -= 1
return not cnt
def capture_exception(self, exception):
capture = bool(self.__allowed_kills)
if capture:
self.__last_captured_exception = exception
else:
self.__last_captured_exception = None
self.__allowed_kills -= 1
return capture
@contextmanager
......@@ -52,74 +68,89 @@ def AllowKill():
This will unmask the kill protection for the current greenlet.
"""
current_greenlet = gevent.getcurrent()
previous_set_mask = MASKED_GREENLETS.pop(current_greenlet, set())
kill_masks = GREENLET_MASK_STATE.pop(current_greenlet, set())
try:
for killmask in previous_set_mask:
if killmask.exception:
raise killmask.exception
for kill_mask in kill_masks:
if kill_mask.last_captured_exception:
raise kill_mask.last_captured_exception
yield
finally:
if previous_set_mask:
MASKED_GREENLETS[current_greenlet] = previous_set_mask
if kill_masks:
GREENLET_MASK_STATE[current_greenlet] = kill_masks
def protect_from_kill(fu):
@wraps(fu)
def func(*args, **kwargs):
def protect_from_kill(method):
@wraps(method)
def wrapper(*args, **kwargs):
with KillMask():
return fu(*args, **kwargs)
return method(*args, **kwargs)
return func
return wrapper
def protect_from_one_kill(fu):
@wraps(fu)
def func(*args, **kwargs):
def protect_from_one_kill(method):
@wraps(method)
def wrapper(*args, **kwargs):
with KillMask(masked_kill_nb=1):
return fu(*args, **kwargs)
return method(*args, **kwargs)
return func
return wrapper
# gevent.greenlet module patch
_ori_timeout = gevent.timeout.Timeout
_GeventTimeout = gevent.timeout.Timeout
_GeventGreenlet = greenlet.Greenlet
class Greenlet(greenlet.Greenlet):
class BlissGreenlet(_GeventGreenlet):
"""KillMask can only work when entered in a greenlet of type `BlissGreenlet`
"""
def throw(self, exception):
if isinstance(exception, gevent.timeout.Timeout):
if isinstance(exception, _GeventTimeout):
return super().throw(exception)
masks = MASKED_GREENLETS.get(self)
if masks:
for m in list(masks):
if m.set_throw(exception):
super().throw(exception)
else:
super().throw(exception)
kill_masks = GREENLET_MASK_STATE.get(self)
if kill_masks:
captured_in_all_masks = True
for kill_mask in list(kill_masks):
captured_in_all_masks &= kill_mask.capture_exception(exception)
if captured_in_all_masks:
return
super().throw(exception)
def get(self, *args, **keys):
try:
return super().get(*args, **keys)
except _ori_timeout as tmout:
t = Timeout(exception=tmout.exception)
raise t
except BlissTimeout:
raise
except _GeventTimeout as tmout:
raise BlissTimeout(exception=tmout.exception)
# timeout patch
class Timeout(gevent.timeout.Timeout):
class BlissTimeout(_GeventTimeout):
"""KillMask can only work when timeouts are of type `BlissGreenlet`
"""
def _on_expiration(self, prev_greenlet, ex):
if isinstance(prev_greenlet, Greenlet): # bliss greenlet
super(Greenlet, prev_greenlet).throw(ex)
else: # default
if isinstance(prev_greenlet, BlissGreenlet):
# Make sure the exception is not captured
super(BlissGreenlet, prev_greenlet).throw(ex)
else:
prev_greenlet.throw(ex)
def patch_gevent():
asyncio.set_event_loop_policy(aiogevent.EventLoopPolicy())
monkey.patch_all(thread=False)
gevent.spawn = Greenlet.spawn
gevent.spawn_later = Greenlet.spawn_later
timeout.Timeout = Timeout
gevent.Timeout = Timeout
gevent.spawn = BlissGreenlet.spawn
gevent.spawn_later = BlissGreenlet.spawn_later
timeout.Timeout = BlissTimeout
gevent.Timeout = BlissTimeout
# For backward compatitibilty
Greenlet = BlissGreenlet
Timeout = BlissTimeout
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment