Commit a3e4130e authored by Wout De Nolf's avatar Wout De Nolf
Browse files

esrftaskgraph: required positional inputs

parent 36a1a790
......@@ -26,6 +26,7 @@ class Task(Registered, hashing.UniversalHashable, register=False):
_INPUT_NAMES = set()
_OPTIONAL_INPUT_NAMES = set()
_OUTPUT_NAMES = set()
_N_REQUIRED_POSITIONAL_INPUTS = 0
def __init__(self, inputs=None, varinfo=None):
"""The named arguments are inputs and Variable configuration"""
......@@ -53,6 +54,13 @@ class Task(Registered, hashing.UniversalHashable, register=False):
# The output hash will update dynamically if any of the input
# variables change
self._inputs = VariableContainer(value=inputs, varinfo=varinfo)
nargs = self._inputs.n_positional_variables
nrequiredargs = self._N_REQUIRED_POSITIONAL_INPUTS
if nargs < nrequiredargs:
raise ValueError(
f"Missing positional inputs for {type(self)}: {nrequiredargs} required but only {nargs} provided"
)
self._outputs = VariableContainer(
value=ovars,
uhash=self._inputs,
......@@ -71,6 +79,7 @@ class Task(Registered, hashing.UniversalHashable, register=False):
input_names=tuple(),
optional_input_names=tuple(),
output_names=tuple(),
n_required_positional_inputs=0,
**kwargs,
):
super().__init_subclass__(**kwargs)
......@@ -94,6 +103,7 @@ class Task(Registered, hashing.UniversalHashable, register=False):
optional_input_names
)
subclass._OUTPUT_NAMES = subclass._OUTPUT_NAMES | set(output_names)
subclass._N_REQUIRED_POSITIONAL_INPUTS = n_required_positional_inputs
@staticmethod
def _reserved_variable_names():
......@@ -158,7 +168,7 @@ class Task(Registered, hashing.UniversalHashable, register=False):
@property
def npositional_inputs(self):
return self._inputs.npositional_variable_values
return self._inputs.n_positional_variables
@property
def output_variables(self):
......
......@@ -254,6 +254,10 @@ class VariableContainer(Mapping, Variable):
self._set_item(i, self.MISSING_DATA)
self.__npositional_vars = nafter
@property
def n_positional_variables(self):
return self.__npositional_vars
def _create_variable(self, key, value):
if isinstance(value, Variable):
return value
......@@ -339,10 +343,6 @@ class VariableContainer(Mapping, Variable):
def named_variable_values(self):
return {k: v.value for k, v in self.items() if isinstance(k, str)}
@property
def npositional_variable_values(self):
return self.__npositional_vars
@property
def positional_variable_values(self):
values = [self.MISSING_DATA] * self.__npositional_vars
......
......@@ -111,3 +111,11 @@ def test_task_storage(tmpdir, varinfo):
assert task.outputs.result == 13
expected += [{"result": str(task.output_variables["result"].uhash)}, 13]
assert_storage(tmpdir, expected)
def test_task_required_positional_inputs():
class MyTask(Task, n_required_positional_inputs=1):
pass
with pytest.raises(ValueError):
MyTask()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment