Commit 9260caf5 authored by Wout De Nolf's avatar Wout De Nolf
Browse files

esrftaskgraph: introduce positional Task inputs

parent c9da8e28
......@@ -4,10 +4,11 @@ from esrftaskgraph.utils import import_method
class MethodExecutorTask(Task, input_names=["method"], output_names=["return_value"]):
def run(self):
method_kwargs = self.input_values
fullname = method_kwargs.pop("method")
kwargs = self.named_input_values
args = self.positional_input_values
fullname = kwargs.pop("method")
method = import_method(fullname)
result = method(**method_kwargs)
result = method(*args, **kwargs)
self.outputs.return_value = result
......@@ -148,6 +148,18 @@ class Task(Registered, hashing.UniversalHashable, register=False):
def input_values(self):
return self._inputs.variable_values
@property
def named_input_values(self):
return self._inputs.named_variable_values
@property
def positional_input_values(self):
return self._inputs.positional_variable_values
@property
def npositional_inputs(self):
return self._inputs.npositional_variable_values
@property
def output_variables(self):
return self._outputs
......
......@@ -2,6 +2,7 @@ import os
import string
import random
import json
from numbers import Integral
from collections.abc import Mapping, MutableMapping, Iterable, Sequence
from contextlib import contextmanager
from esrftaskgraph import hashing
......@@ -201,6 +202,7 @@ class VariableContainer(Mapping, Variable):
def __init__(self, **kw):
value = kw.pop("value", None)
self.__varparams = kw
self.__npositional_vars = 0
super().__init__(**kw)
if value:
self._update(value)
......@@ -225,21 +227,39 @@ class VariableContainer(Mapping, Variable):
self._set_item(*tpl)
def _set_item(self, name, value):
self._initialize_container()
self._fill_missing_positions(name)
self.value[name] = self._create_variable(name, value)
def _fill_missing_positions(self, name):
if not self.is_positional_variable(name):
return
nbefore = self.__npositional_vars
nafter = max(nbefore, name + 1)
for i in range(nbefore, nafter - 1):
self._set_item(i, self.MISSING_DATA)
self.__npositional_vars = nafter
@staticmethod
def is_positional_variable(name):
return isinstance(name, Integral) and name >= 0
def _create_variable(self, name, value):
if isinstance(value, Variable):
var = value
return value
varparams = dict(self.__varparams)
if isinstance(value, hashing.UniversalHash):
varparams["uhash"] = value
varparams["uhash_nonce"] = None
else:
varparams = dict(self.__varparams)
if isinstance(value, hashing.UniversalHash):
varparams["uhash"] = value
varparams["uhash_nonce"] = None
else:
varparams["value"] = value
uhash_nonce = varparams.pop("uhash_nonce", None)
varparams["uhash_nonce"] = uhash_nonce, name
var = Variable(**varparams)
varparams["value"] = value
uhash_nonce = varparams.pop("uhash_nonce", None)
varparams["uhash_nonce"] = uhash_nonce, name
return Variable(**varparams)
def _initialize_container(self):
if not self.container_available and not self.container_exists:
self.value = dict()
self.value[name] = var
def __iter__(self):
adict = self.value
......@@ -309,6 +329,24 @@ class VariableContainer(Mapping, Variable):
def variable_values(self):
return {k: v.value for k, v in self.items()}
@property
def named_variable_values(self):
return {
k: v.value for k, v in self.items() if not self.is_positional_variable(k)
}
@property
def npositional_variable_values(self):
return self.__npositional_vars
@property
def positional_variable_values(self):
values = [self.MISSING_DATA] * self.__npositional_vars
for i, var in self.items():
if self.is_positional_variable(i):
values[i] = var.value
return tuple(values)
def update_values(self, items):
if isinstance(items, Mapping):
items = items.items()
......
......@@ -6,8 +6,12 @@ from esrftaskgraph.task import Task
from tasklib.tasks import SumTask
def mymethod(a=0, b=0):
return {"result": a + b}
def mymethod1(a=0, b=0):
return a + b
def mymethod2(*args):
return sum(args)
def myppfmethod(a=0, b=0, **kw):
......@@ -126,12 +130,21 @@ def test_task_storage(tmpdir, varinfo):
def test_method_task(varinfo):
task = Task.instantiate(
"MethodExecutorTask",
inputs={"method": qualname(mymethod), "a": 3, "b": 5},
inputs={"method": qualname(mymethod1), "a": 3, "b": 5},
varinfo=varinfo,
)
task.execute()
assert task.done
assert task.output_values == {"return_value": 8}
task = Task.instantiate(
"MethodExecutorTask",
inputs={"method": qualname(mymethod2), 0: 3, 1: 5},
varinfo=varinfo,
)
task.execute()
assert task.done
assert task.output_values == {"return_value": {"result": 8}}
assert task.output_values == {"return_value": 8}
def test_ppfmethod_task(varinfo):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment