Merge branch 'master' into hash-change-guards
tclose authored Mar 2, 2024
2 parents 4e1d4a8 + 0e66136 commit b94f185
Showing 10 changed files with 538 additions and 94 deletions.
4 changes: 3 additions & 1 deletion pydra/engine/helpers.py
@@ -263,7 +263,9 @@ def make_klass(spec):
                 **kwargs,
             )
         checker_label = f"'{name}' field of {spec.name}"
-        type_checker = TypeParser[newfield.type](newfield.type, label=checker_label)
+        type_checker = TypeParser[newfield.type](
+            newfield.type, label=checker_label, superclass_auto_cast=True
+        )
         if newfield.type in (MultiInputObj, MultiInputFile):
             converter = attr.converters.pipe(ensure_list, type_checker)
         elif newfield.type in (MultiOutputObj, MultiOutputFile):
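An illustrative aside, not part of the diff: the hunk above switches on superclass_auto_cast for the TypeParser that make_klass attaches as each field's attrs converter. The sketch below only mirrors that construction pattern; the import path pydra.utils.typing, the fileformats Plain type, and the field/spec names are assumptions made for illustration, as is the reading that the flag lets a value typed as a superclass of the declared type be accepted and cast rather than rejected.

from fileformats.text import Plain  # illustrative target type
from pydra.utils.typing import TypeParser  # assumed import path

# Same construction pattern as in make_klass above: a parser specialised to the
# field's declared type, used as the attrs converter for that field.
checker = TypeParser[Plain](
    Plain,
    label="'in_file' field of ExampleSpec",  # hypothetical field/spec names
    superclass_auto_cast=True,
)
# Assumed behaviour: with superclass_auto_cast=True a value that is only known
# to be a superclass of Plain (e.g. a lazy upstream output typed as the generic
# fileformats File) can be accepted and cast instead of failing the type check.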
11 changes: 7 additions & 4 deletions pydra/engine/specs.py
@@ -449,16 +449,19 @@ def collect_additional_outputs(self, inputs, output_dir, outputs):
                 ),
             ):
                 raise TypeError(
-                    f"Support for {fld.type} type, required for {fld.name} in {self}, "
+                    f"Support for {fld.type} type, required for '{fld.name}' in {self}, "
                     "has not been implemented in collect_additional_output"
                 )
             # assuming that field should have either default or metadata, but not both
             input_value = getattr(inputs, fld.name, attr.NOTHING)
             if input_value is not attr.NOTHING:
                 if TypeParser.contains_type(FileSet, fld.type):
-                    label = f"output field '{fld.name}' of {self}"
-                    input_value = TypeParser(fld.type, label=label).coerce(input_value)
-                    additional_out[fld.name] = input_value
+                    if input_value is not False:
+                        label = f"output field '{fld.name}' of {self}"
+                        input_value = TypeParser(fld.type, label=label).coerce(
+                            input_value
+                        )
+                        additional_out[fld.name] = input_value
             elif (
                 fld.default is None or fld.default == attr.NOTHING
             ) and not fld.metadata:  # TODO: is it right?
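An illustrative aside, not part of the diff: the new `is not False` guard lets an optional, template-generated output be switched off by passing False, in which case collect_additional_outputs now skips it instead of trying to coerce False into a FileSet. The new test_shell_cmd_optional_output_file* tests added to pydra/engine/tests/test_shelltask.py further down exercise this end to end; the fragment below is just the field shape involved, lifted from those tests, with the field name and argstr as illustrative placeholders.

import typing as ty
from pathlib import Path

import attr

# An optional templated output: output_file_template supplies the path when the
# flag is used, while the Union[Path, bool] type with default=False lets the
# user turn the output off entirely.
optional_output_field = (
    "unused",
    attr.ib(
        type=ty.Union[Path, bool],
        default=False,
        metadata={
            "argstr": "--not-used",
            "output_file_template": "out.txt",
            "help_string": "dummy output",
        },
    ),
)
# With the guard above, leaving such a field as False means it is simply
# omitted from the collected outputs rather than coerced.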
29 changes: 20 additions & 9 deletions pydra/engine/submitter.py
@@ -1,9 +1,10 @@
"""Handle execution backends."""

 import asyncio
+import typing as ty
 import pickle
 from uuid import uuid4
-from .workers import WORKERS
+from .workers import Worker, WORKERS
 from .core import is_workflow
 from .helpers import get_open_loop, load_and_run_async

@@ -16,24 +17,34 @@
 class Submitter:
     """Send a task to the execution backend."""

-    def __init__(self, plugin="cf", **kwargs):
+    def __init__(self, plugin: ty.Union[str, ty.Type[Worker]] = "cf", **kwargs):
         """
         Initialize task submission.

         Parameters
         ----------
-        plugin : :obj:`str`
-            The identifier of the execution backend.
+        plugin : :obj:`str` or :obj:`ty.Type[pydra.engine.core.Worker]`
+            Either the identifier of the execution backend or the worker class itself.
             Default is ``cf`` (Concurrent Futures).
         **kwargs
             Additional keyword arguments to pass to the worker.
         """
         self.loop = get_open_loop()
         self._own_loop = not self.loop.is_running()
-        self.plugin = plugin
-        try:
-            self.worker = WORKERS[self.plugin](**kwargs)
-        except KeyError:
-            raise NotImplementedError(f"No worker for {self.plugin}")
+        if isinstance(plugin, str):
+            self.plugin = plugin
+            try:
+                worker_cls = WORKERS[self.plugin]
+            except KeyError:
+                raise NotImplementedError(f"No worker for '{self.plugin}' plugin")
+        else:
+            try:
+                self.plugin = plugin.plugin_name
+            except AttributeError:
+                raise ValueError("Worker class must have a 'plugin_name' str attribute")
+            worker_cls = plugin
+        self.worker = worker_cls(**kwargs)
         self.worker.loop = self.loop

     def __call__(self, runnable, cache_locations=None, rerun=False, environment=None):
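An illustrative aside, not part of the diff: Submitter now accepts either a registered plugin name or a Worker subclass that defines a plugin_name attribute. The commit's own end-to-end demonstration is the BYOAddVarWorker test added to pydra/engine/tests/test_submitter.py below; the sketch here only shows the two call forms, with MyWorker as a hypothetical bring-your-own worker.

from pydra.engine.submitter import Submitter
from pydra.engine.workers import SerialWorker


class MyWorker(SerialWorker):
    """Hypothetical bring-your-own worker."""

    plugin_name = "my_worker"  # required: without it Submitter raises ValueError


# Built-in plugin looked up by name in the WORKERS registry (unchanged form).
with Submitter(plugin="cf") as sub:
    ...

# Worker class passed directly (new form); the plugin name comes from the class.
with Submitter(plugin=MyWorker) as sub:
    assert sub.plugin == "my_worker"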
16 changes: 1 addition & 15 deletions pydra/engine/tests/test_node_task.py
@@ -133,21 +133,7 @@ def test_task_init_3a(


 def test_task_init_4():
-    """task with interface and inputs. splitter set using split method"""
-    nn = fun_addtwo(name="NA")
-    nn.split(splitter="a", a=[3, 5])
-    assert np.allclose(nn.inputs.a, [3, 5])
-
-    assert nn.state.splitter == "NA.a"
-    assert nn.state.splitter_rpn == ["NA.a"]
-
-    nn.state.prepare_states(nn.inputs)
-    assert nn.state.states_ind == [{"NA.a": 0}, {"NA.a": 1}]
-    assert nn.state.states_val == [{"NA.a": 3}, {"NA.a": 5}]
-
-
-def test_task_init_4a():
-    """task with a splitter and inputs set in the split method"""
+    """task with interface splitter and inputs set in the split method"""
     nn = fun_addtwo(name="NA")
     nn.split(splitter="a", a=[3, 5])
     assert np.allclose(nn.inputs.a, [3, 5])
96 changes: 96 additions & 0 deletions pydra/engine/tests/test_shelltask.py
@@ -4347,6 +4347,102 @@ def change_name(file):
     # res = shelly(plugin="cf")


+def test_shell_cmd_optional_output_file1(tmp_path):
+    """
+    Test to see that 'unused' doesn't complain about not having an output passed to it
+    """
+    my_cp_spec = SpecInfo(
+        name="Input",
+        fields=[
+            (
+                "input",
+                attr.ib(
+                    type=File, metadata={"argstr": "", "help_string": "input file"}
+                ),
+            ),
+            (
+                "output",
+                attr.ib(
+                    type=Path,
+                    metadata={
+                        "argstr": "",
+                        "output_file_template": "out.txt",
+                        "help_string": "output file",
+                    },
+                ),
+            ),
+            (
+                "unused",
+                attr.ib(
+                    type=ty.Union[Path, bool],
+                    default=False,
+                    metadata={
+                        "argstr": "--not-used",
+                        "output_file_template": "out.txt",
+                        "help_string": "dummy output",
+                    },
+                ),
+            ),
+        ],
+        bases=(ShellSpec,),
+    )
+
+    my_cp = ShellCommandTask(
+        name="my_cp",
+        executable="cp",
+        input_spec=my_cp_spec,
+    )
+    file1 = tmp_path / "file1.txt"
+    file1.write_text("foo")
+    result = my_cp(input=file1, unused=False)
+    assert result.output.output.fspath.read_text() == "foo"
+
+
+def test_shell_cmd_optional_output_file2(tmp_path):
+    """
+    Test to see that 'unused' doesn't complain about not having an output passed to it
+    """
+    my_cp_spec = SpecInfo(
+        name="Input",
+        fields=[
+            (
+                "input",
+                attr.ib(
+                    type=File, metadata={"argstr": "", "help_string": "input file"}
+                ),
+            ),
+            (
+                "output",
+                attr.ib(
+                    type=ty.Union[Path, bool],
+                    default=False,
+                    metadata={
+                        "argstr": "",
+                        "output_file_template": "out.txt",
+                        "help_string": "dummy output",
+                    },
+                ),
+            ),
+        ],
+        bases=(ShellSpec,),
+    )
+
+    my_cp = ShellCommandTask(
+        name="my_cp",
+        executable="cp",
+        input_spec=my_cp_spec,
+    )
+    file1 = tmp_path / "file1.txt"
+    file1.write_text("foo")
+    result = my_cp(input=file1, output=True)
+    assert result.output.output.fspath.read_text() == "foo"
+
+    file2 = tmp_path / "file2.txt"
+    file2.write_text("bar")
+    with pytest.raises(RuntimeError):
+        my_cp(input=file2, output=False)
+
+
 def test_shell_cmd_non_existing_outputs_1(tmp_path):
     """Checking that non existing output files do not return a phantom path,
     but return NOTHING instead"""
67 changes: 66 additions & 1 deletion pydra/engine/tests/test_submitter.py
@@ -6,6 +6,8 @@
 import attrs
 import typing as ty
 from random import randint
+import os
+from unittest.mock import patch
 import pytest
 from fileformats.generic import Directory
 from .utils import (
@@ -15,8 +17,9 @@
     gen_basic_wf_with_threadcount,
     gen_basic_wf_with_threadcount_concurrent,
 )
-from ..core import Workflow
+from ..core import Workflow, TaskBase
 from ..submitter import Submitter
+from ..workers import SerialWorker
 from ... import mark
 from pathlib import Path
 from datetime import datetime
@@ -665,3 +668,65 @@ def to_tuple(x, y):
     ):
         with Submitter("cf") as sub:
             result = sub(wf)
+
+@mark.task
+def to_tuple(x, y):
+    return (x, y)
+
+
+class BYOAddVarWorker(SerialWorker):
+    """A dummy worker that adds 1 to the output of the task"""
+
+    plugin_name = "byo_add_env_var"
+
+    def __init__(self, add_var, **kwargs):
+        super().__init__(**kwargs)
+        self.add_var = add_var
+
+    async def exec_serial(self, runnable, rerun=False, environment=None):
+        if isinstance(runnable, TaskBase):
+            with patch.dict(os.environ, {"BYO_ADD_VAR": str(self.add_var)}):
+                result = runnable._run(rerun, environment=environment)
+            return result
+        else:  # it could be tuple that includes pickle files with tasks and inputs
+            return super().exec_serial(runnable, rerun, environment)
+
+
+@mark.task
+def add_env_var_task(x: int) -> int:
+    return x + int(os.environ.get("BYO_ADD_VAR", 0))
+
+
+def test_byo_worker():
+
+    task1 = add_env_var_task(x=1)
+
+    with Submitter(plugin=BYOAddVarWorker, add_var=10) as sub:
+        assert sub.plugin == "byo_add_env_var"
+        result = task1(submitter=sub)
+
+    assert result.output.out == 11
+
+    task2 = add_env_var_task(x=2)
+
+    with Submitter(plugin="serial") as sub:
+        result = task2(submitter=sub)
+
+    assert result.output.out == 2
+
+
+def test_bad_builtin_worker():
+
+    with pytest.raises(NotImplementedError, match="No worker for 'bad-worker' plugin"):
+        Submitter(plugin="bad-worker")
+
+
+def test_bad_byo_worker():
+
+    class BadWorker:
+        pass
+
+    with pytest.raises(
+        ValueError, match="Worker class must have a 'plugin_name' str attribute"
+    ):
+        Submitter(plugin=BadWorker)
