Commit 36a1a790 authored by Wout De Nolf's avatar Wout De Nolf
Browse files

esrftaskgraph: restrict variable identifiers to str and int

parent 7fc82c0d
......@@ -193,6 +193,8 @@ class Variable(hashing.UniversalHashable):
class VariableContainer(Mapping, Variable):
"""An immutable mapping of variable identifiers (str or int) to variables (Variable)."""
def __init__(self, **kw):
value = kw.pop("value", None)
self.__varparams = kw
......@@ -201,8 +203,8 @@ class VariableContainer(Mapping, Variable):
if value:
self._update(value)
def __getitem__(self, name):
return self.value[name]
def __getitem__(self, key):
return self.value[key]
def _update(self, value):
if isinstance(value, Mapping):
......@@ -220,25 +222,39 @@ class VariableContainer(Mapping, Variable):
)
self._set_item(*tpl)
def _set_item(self, name, value):
self._initialize_container()
self._fill_missing_positions(name)
self.value[name] = self._create_variable(name, value)
def _set_item(self, key, value):
key = self._parse_variable_name(key)
if isinstance(key, int):
self._fill_missing_positions(key)
if not self.container_has_value:
self.value = dict()
self.value[key] = self._create_variable(key, value)
def _fill_missing_positions(self, name):
if not self.is_positional_variable(name):
return
def _parse_variable_name(self, key):
"""Variables are identified by a `str` or an `int`. A key like "1" will
be converted to an `int` (e.g. json dump converts `int` to `str`).
"""
if isinstance(key, str):
if key.isdigit():
key = int(key)
if isinstance(key, Integral):
key = int(key)
if key < 0:
raise ValueError("Negative argument positions are not allowed")
elif not isinstance(key, str):
raise TypeError(
f"Variable {key} must be a string or positive integer", type(key)
)
return key
def _fill_missing_positions(self, pos):
nbefore = self.__npositional_vars
nafter = max(nbefore, name + 1)
nafter = max(nbefore, pos + 1)
for i in range(nbefore, nafter - 1):
self._set_item(i, self.MISSING_DATA)
self.__npositional_vars = nafter
@staticmethod
def is_positional_variable(name):
return isinstance(name, Integral) and name >= 0
def _create_variable(self, name, value):
def _create_variable(self, key, value):
if isinstance(value, Variable):
return value
varparams = dict(self.__varparams)
......@@ -248,16 +264,9 @@ class VariableContainer(Mapping, Variable):
else:
varparams["value"] = value
uhash_nonce = varparams.pop("uhash_nonce", None)
varparams["uhash_nonce"] = uhash_nonce, name
varparams["uhash_nonce"] = uhash_nonce, key
return Variable(**varparams)
def _initialize_container(self):
if (
not self.container_has_runtime_value
and not self.container_has_persistent_value
):
self.value = dict()
def __iter__(self):
adict = self.value
if isinstance(adict, dict):
......@@ -309,6 +318,10 @@ class VariableContainer(Mapping, Variable):
else:
return False
@property
def container_has_value(self):
return self.container_has_runtime_value or self.container_has_persistent_value
def force_non_existing(self):
super().force_non_existing()
for v in self.values():
......@@ -324,9 +337,7 @@ class VariableContainer(Mapping, Variable):
@property
def named_variable_values(self):
return {
k: v.value for k, v in self.items() if not self.is_positional_variable(k)
}
return {k: v.value for k, v in self.items() if isinstance(k, str)}
@property
def npositional_variable_values(self):
......@@ -336,7 +347,7 @@ class VariableContainer(Mapping, Variable):
def positional_variable_values(self):
values = [self.MISSING_DATA] * self.__npositional_vars
for i, var in self.items():
if self.is_positional_variable(i):
if isinstance(i, int):
values[i] = var.value
return tuple(values)
......@@ -348,13 +359,15 @@ class VariableContainer(Mapping, Variable):
class MutableVariableContainer(VariableContainer, MutableMapping):
def __setitem__(self, name, value):
self._set_item(name, value)
"""An mutable mapping of variable identifiers (str or int) to variables (Variable)."""
def __setitem__(self, key, value):
self._set_item(key, value)
def __delitem__(self, name):
def __delitem__(self, key):
adict = self.value
if isinstance(adict, dict):
del self.value[name]
del self.value[key]
class MissingVariableError(RuntimeError):
......@@ -366,7 +379,7 @@ class ReadOnlyVariableError(RuntimeError):
class ReadOnlyVariableContainerNamespace:
"""Expose getting variable values through attributes"""
"""Expose getting variable values through attributes and indexing"""
def __init__(self, container):
self._container = container
......@@ -379,31 +392,31 @@ class ReadOnlyVariableContainerNamespace:
cls._RESERVED_VARIABLE_NAMES = set(dir(cls)) | {"_container"}
return cls._RESERVED_VARIABLE_NAMES
def __setattr__(self, name, value):
if name == "_container":
super().__setattr__(name, value)
def __setattr__(self, attrname, value):
if attrname == "_container":
super().__setattr__(attrname, value)
else:
self._get_variable(name)
raise ReadOnlyVariableError(name)
self._get_variable(attrname)
raise ReadOnlyVariableError(attrname)
def __getattr__(self, name):
return self[name]
def __getattr__(self, attrname):
return self[attrname]
def __getitem__(self, name):
return self._get_variable(name).value
def __getitem__(self, key):
return self._get_variable(key).value
def _get_variable(self, name):
def _get_variable(self, key):
try:
return self._container[name]
return self._container[key]
except (KeyError, TypeError):
raise MissingVariableError(name)
raise MissingVariableError(key)
class VariableContainerNamespace(ReadOnlyVariableContainerNamespace):
"""Expose getting/setting variable values through attributes"""
"""Expose getting and setting variable values through attributes and indexing"""
def __setattr__(self, name, value):
if name == "_container":
super().__setattr__(name, value)
def __setattr__(self, attrname, value):
if attrname == "_container":
super().__setattr__(attrname, value)
else:
self._get_variable(name).value = value
self._get_variable(attrname).value = value
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment