Commit a988379f authored by Wout De Nolf's avatar Wout De Nolf
Browse files

make Task.inputs namespace read-only

parent d1aca94b
......@@ -2,6 +2,7 @@ from esrftaskgraph import hashing
from esrftaskgraph.variable import VARINFO
from esrftaskgraph.variable import VariableContainer
from esrftaskgraph.variable import VariableContainerNamespace
from esrftaskgraph.variable import ReadOnlyVariableContainerNamespace
from esrftaskgraph.registration import Registered
......@@ -53,7 +54,7 @@ class Task(Registered, hashing.UniversalHashable, register=False):
value=ovars, uhash=self._inputs, uhash_nonce=self.class_nonce(), **varkw
)
self._user_inputs = VariableContainerNamespace(self._inputs)
self._user_inputs = ReadOnlyVariableContainerNamespace(self._inputs)
self._user_outputs = VariableContainerNamespace(self._outputs)
# The task class has the same hash as its output
......
......@@ -326,7 +326,11 @@ class MutableVariableContainer(VariableContainer, MutableMapping):
del self.value[name]
class UnknownVariableError(AttributeError):
class UnknownVariableError(RuntimeError):
pass
class ReadOnlyVariableError(RuntimeError):
pass
......@@ -358,3 +362,11 @@ class VariableContainerNamespace:
return self._container[name]
except (KeyError, TypeError):
raise UnknownVariableError(name)
class ReadOnlyVariableContainerNamespace(VariableContainerNamespace):
def __setattr__(self, name, value):
if name in self._reserved_variable_names():
super(VariableContainerNamespace, self).__setattr__(name, value)
else:
raise ReadOnlyVariableError(name)
......@@ -36,6 +36,12 @@ def test_task_missing_input(variable_kwargs):
SumTask(**variable_kwargs)
def test_task_readonly_input(variable_kwargs):
task = SumTask(a=10, **variable_kwargs)
with pytest.raises(RuntimeError):
task.inputs.a = 10
def test_task_optional_input(tmpdir, variable_kwargs):
task = SumTask(a=10, **variable_kwargs)
assert not task.done
......@@ -52,7 +58,7 @@ def test_task_uhash(variable_kwargs):
assert task.uhash == task.output_variables.uhash
assert task.uhash != task.input_variables.uhash
task.inputs.a += 1
task.input_variables["a"].value += 1
assert task.uhash != uhash
assert task.uhash == task.output_variables.uhash
assert task.uhash != task.input_variables.uhash
......@@ -60,7 +66,7 @@ def test_task_uhash(variable_kwargs):
task.run()
assert task.done
task.inputs.a += 1
task.input_variables["a"].value += 1
assert task.uhash != uhash
assert task.uhash == task.output_variables.uhash
assert task.uhash != task.input_variables.uhash
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment