Commit d8730558 authored by Wout De Nolf's avatar Wout De Nolf
Browse files

greenlet_utils: avoid globals and avoid keeping references to greenlet objects

We cannot use gevent.local.local because BlissGreenlet.throw is executed in the hub
parent f80f9fce
Pipeline #49029 passed with stages
in 135 minutes and 28 seconds
from contextlib import contextmanager
from functools import wraps
from contextlib import contextmanager
import logging
import asyncio
import aiogevent
......@@ -7,7 +8,8 @@ from gevent import monkey
from gevent import greenlet, timeout
import gevent
GREENLET_MASK_STATE = dict()
logger = logging.getLogger(__name__)
class KillMask:
......@@ -15,10 +17,10 @@ class KillMask:
call on the greenlet in which the KillMask context is entered,
will be delayed until context exit, except for `gevent.Timeout`.
Upon exit, only the last capture exception is re-raised.
Upon exit, only the last captured exception is re-raised.
Optionally we can set a limit to the number of `kill` calls
will be delayed.
that will be delayed.
Warning: this does not delay interrupts.
"""
......@@ -30,22 +32,33 @@ class KillMask:
== 0 no kill is delayed
< 0 unlimited
"""
self.__greenlet = gevent.getcurrent()
self.__masked_kill_nb = masked_kill_nb
self.__allowed_kills = masked_kill_nb
self.__allowed_capture_nb = masked_kill_nb
self.__last_captured_exception = None
@property
def _bliss_greenlet(self):
glt = gevent.getcurrent()
if isinstance(glt, BlissGreenlet):
return glt
elif glt.parent is not None:
logger.warning("KillMask will not work in the current greenlet: %s", glt)
return None
def __enter__(self):
self.__allowed_kills = self.__masked_kill_nb
glt = self._bliss_greenlet
if glt is None:
return
self.__allowed_capture_nb = self.__masked_kill_nb
self.__last_captured_exception = None
GREENLET_MASK_STATE.setdefault(self.__greenlet, set()).add(self)
glt.kill_masks.add(self)
def __exit__(self, exc_type, value, traceback):
GREENLET_MASK_STATE[self.__greenlet].remove(self)
if GREENLET_MASK_STATE[self.__greenlet]:
glt = self._bliss_greenlet
if glt is None:
return
GREENLET_MASK_STATE.pop(self.__greenlet)
if self.__last_captured_exception is not None:
glt.kill_masks.remove(self)
if not glt.kill_masks and self.__last_captured_exception is not None:
raise self.__last_captured_exception
@property
......@@ -53,12 +66,12 @@ class KillMask:
return self.__last_captured_exception
def capture_exception(self, exception):
capture = bool(self.__allowed_kills)
capture = bool(self.__allowed_capture_nb)
if capture:
self.__last_captured_exception = exception
self.__allowed_capture_nb -= 1
else:
self.__last_captured_exception = None
self.__allowed_kills -= 1
return capture
......@@ -67,16 +80,15 @@ def AllowKill():
"""
This will unmask the kill protection for the current greenlet.
"""
current_greenlet = gevent.getcurrent()
kill_masks = GREENLET_MASK_STATE.pop(current_greenlet, set())
try:
for kill_mask in kill_masks:
if kill_mask.last_captured_exception:
raise kill_mask.last_captured_exception
glt = gevent.getcurrent()
if isinstance(glt, BlissGreenlet):
with glt.disable_kill_masks() as kill_masks:
for kill_mask in kill_masks:
if kill_mask.last_captured_exception:
raise kill_mask.last_captured_exception
yield
else:
yield
finally:
if kill_masks:
GREENLET_MASK_STATE[current_greenlet] = kill_masks
def protect_from_kill(method):
......@@ -105,14 +117,33 @@ class BlissGreenlet(_GeventGreenlet):
"""KillMask can only work when entered in a greenlet of type `BlissGreenlet`
"""
def __init__(self, *args, **kw):
super().__init__(*args, **kw)
self.__kill_masks = set()
@property
def kill_masks(self):
return self.__kill_masks
@contextmanager
def disable_kill_masks(self):
kill_masks = self.__kill_masks
self.__kill_masks = set()
try:
yield kill_masks
finally:
self.__kill_masks = kill_masks
def throw(self, exception):
# This is executed in the Hub which is the reason
# we cannot use gevent.local.local to store the
# kill masks for each greenlet.
if isinstance(exception, _GeventTimeout):
return super().throw(exception)
kill_masks = GREENLET_MASK_STATE.get(self)
if kill_masks:
if self.__kill_masks:
captured_in_all_masks = True
for kill_mask in list(kill_masks):
for kill_mask in self.__kill_masks:
captured_in_all_masks &= kill_mask.capture_exception(exception)
if captured_in_all_masks:
return
......@@ -129,12 +160,13 @@ class BlissGreenlet(_GeventGreenlet):
class BlissTimeout(_GeventTimeout):
"""KillMask can only work when timeouts are of type `BlissGreenlet`
"""KillMask can only work when timeouts are of type `BlissTimeout`
"""
def _on_expiration(self, prev_greenlet, ex):
if isinstance(prev_greenlet, BlissGreenlet):
# Make sure the exception is not captured
# Make sure the exception is not captured by
# a KillMask
super(BlissGreenlet, prev_greenlet).throw(ex)
else:
prev_greenlet.throw(ex)
......@@ -145,12 +177,13 @@ def patch_gevent():
monkey.patch_all(thread=False)
# For KillMask
gevent.spawn = BlissGreenlet.spawn
gevent.spawn_later = BlissGreenlet.spawn_later
timeout.Timeout = BlissTimeout
gevent.Timeout = BlissTimeout
# For backward compatitibilty
# For backward compatibility
Greenlet = BlissGreenlet
Timeout = BlissTimeout
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment