Skip to content

Commit

Permalink
Merge pull request #696 from tclose/subclass-permissive-typing
Browse files Browse the repository at this point in the history
Permit superclass to subclass lazy typing
  • Loading branch information
djarecka authored Feb 27, 2024
2 parents 1720ba6 + 27e7fb8 commit 1858668
Show file tree
Hide file tree
Showing 6 changed files with 316 additions and 62 deletions.
4 changes: 3 additions & 1 deletion pydra/engine/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,9 @@ def make_klass(spec):
**kwargs,
)
checker_label = f"'{name}' field of {spec.name}"
type_checker = TypeParser[newfield.type](newfield.type, label=checker_label)
type_checker = TypeParser[newfield.type](
newfield.type, label=checker_label, superclass_auto_cast=True
)
if newfield.type in (MultiInputObj, MultiInputFile):
converter = attr.converters.pipe(ensure_list, type_checker)
elif newfield.type in (MultiOutputObj, MultiOutputFile):
Expand Down
2 changes: 1 addition & 1 deletion pydra/engine/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ def collect_additional_outputs(self, inputs, output_dir, outputs):
),
):
raise TypeError(
f"Support for {fld.type} type, required for {fld.name} in {self}, "
f"Support for {fld.type} type, required for '{fld.name}' in {self}, "
"has not been implemented in collect_additional_output"
)
# assuming that field should have either default or metadata, but not both
Expand Down
16 changes: 1 addition & 15 deletions pydra/engine/tests/test_node_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,21 +133,7 @@ def test_task_init_3a(


def test_task_init_4():
"""task with interface and inputs. splitter set using split method"""
nn = fun_addtwo(name="NA")
nn.split(splitter="a", a=[3, 5])
assert np.allclose(nn.inputs.a, [3, 5])

assert nn.state.splitter == "NA.a"
assert nn.state.splitter_rpn == ["NA.a"]

nn.state.prepare_states(nn.inputs)
assert nn.state.states_ind == [{"NA.a": 0}, {"NA.a": 1}]
assert nn.state.states_val == [{"NA.a": 3}, {"NA.a": 5}]


def test_task_init_4a():
"""task with a splitter and inputs set in the split method"""
"""task with interface splitter and inputs set in the split method"""
nn = fun_addtwo(name="NA")
nn.split(splitter="a", a=[3, 5])
assert np.allclose(nn.inputs.a, [3, 5])
Expand Down
185 changes: 170 additions & 15 deletions pydra/utils/tests/test_typing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import itertools
import sys
import typing as ty
from pathlib import Path
import tempfile
Expand All @@ -8,13 +9,16 @@
from ...engine.specs import File, LazyOutField
from ..typing import TypeParser
from pydra import Workflow
from fileformats.application import Json
from fileformats.application import Json, Yaml, Xml
from .utils import (
generic_func_task,
GenericShellTask,
specific_func_task,
SpecificShellTask,
other_specific_func_task,
OtherSpecificShellTask,
MyFormatX,
MyOtherFormatX,
MyHeader,
)

Expand Down Expand Up @@ -152,8 +156,12 @@ def test_type_check_nested6():


def test_type_check_nested7():
TypeParser(ty.Tuple[float, float, float])(lz(ty.List[int]))


def test_type_check_nested7a():
with pytest.raises(TypeError, match="Wrong number of type arguments"):
TypeParser(ty.Tuple[float, float, float])(lz(ty.List[int]))
TypeParser(ty.Tuple[float, float, float])(lz(ty.Tuple[int]))


def test_type_check_nested8():
Expand All @@ -164,6 +172,18 @@ def test_type_check_nested8():
)(lz(ty.List[float]))


def test_type_check_permit_superclass():
# Typical case as Json is subclass of File
TypeParser(ty.List[File])(lz(ty.List[Json]))
# Permissive super class, as File is superclass of Json
TypeParser(ty.List[Json], superclass_auto_cast=True)(lz(ty.List[File]))
with pytest.raises(TypeError, match="Cannot coerce"):
TypeParser(ty.List[Json], superclass_auto_cast=False)(lz(ty.List[File]))
# Fails because Yaml is neither sub or super class of Json
with pytest.raises(TypeError, match="Cannot coerce"):
TypeParser(ty.List[Json], superclass_auto_cast=True)(lz(ty.List[Yaml]))


def test_type_check_fail1():
with pytest.raises(TypeError, match="Wrong number of type arguments in tuple"):
TypeParser(ty.Tuple[int, int, int])(lz(ty.Tuple[float, float, float, float]))
Expand Down Expand Up @@ -490,14 +510,29 @@ def test_matches_type_tuple():
assert not TypeParser.matches_type(ty.Tuple[int], ty.Tuple[int, int])


def test_matches_type_tuple_ellipsis():
def test_matches_type_tuple_ellipsis1():
assert TypeParser.matches_type(ty.Tuple[int], ty.Tuple[int, ...])


def test_matches_type_tuple_ellipsis2():
assert TypeParser.matches_type(ty.Tuple[int, int], ty.Tuple[int, ...])


def test_matches_type_tuple_ellipsis3():
assert not TypeParser.matches_type(ty.Tuple[int, float], ty.Tuple[int, ...])
assert not TypeParser.matches_type(ty.Tuple[int, ...], ty.Tuple[int])


def test_matches_type_tuple_ellipsis4():
assert TypeParser.matches_type(ty.Tuple[int, ...], ty.Tuple[int])


def test_matches_type_tuple_ellipsis5():
assert TypeParser.matches_type(
ty.Tuple[int], ty.List[int], coercible=[(tuple, list)]
)


def test_matches_type_tuple_ellipsis6():
assert TypeParser.matches_type(
ty.Tuple[int, ...], ty.List[int], coercible=[(tuple, list)]
)
Expand Down Expand Up @@ -538,7 +573,17 @@ def specific_task(request):
assert False


def test_typing_cast(tmp_path, generic_task, specific_task):
@pytest.fixture(params=["func", "shell"])
def other_specific_task(request):
if request.param == "func":
return other_specific_func_task
elif request.param == "shell":
return OtherSpecificShellTask
else:
assert False


def test_typing_implicit_cast_from_super(tmp_path, generic_task, specific_task):
"""Check the casting of lazy fields and whether specific file-sets can be recovered
from generic `File` classes"""

Expand All @@ -562,33 +607,86 @@ def test_typing_cast(tmp_path, generic_task, specific_task):
)
)

wf.add(
specific_task(
in_file=wf.generic.lzout.out,
name="specific2",
)
)

wf.set_output(
[
("out_file", wf.specific2.lzout.out),
]
)

in_file = MyFormatX.sample()

result = wf(in_file=in_file, plugin="serial")

out_file: MyFormatX = result.output.out_file
assert type(out_file) is MyFormatX
assert out_file.parent != in_file.parent
assert type(out_file.header) is MyHeader
assert out_file.header.parent != in_file.header.parent


def test_typing_cast(tmp_path, specific_task, other_specific_task):
"""Check the casting of lazy fields and whether specific file-sets can be recovered
from generic `File` classes"""

wf = Workflow(
name="test",
input_spec={"in_file": MyFormatX},
output_spec={"out_file": MyFormatX},
)

wf.add(
specific_task(
in_file=wf.lzin.in_file,
name="entry",
)
)

with pytest.raises(TypeError, match="Cannot coerce"):
# No cast of generic task output to MyFormatX
wf.add( # Generic task
other_specific_task(
in_file=wf.entry.lzout.out,
name="inner",
)
)

wf.add( # Generic task
other_specific_task(
in_file=wf.entry.lzout.out.cast(MyOtherFormatX),
name="inner",
)
)

with pytest.raises(TypeError, match="Cannot coerce"):
# No cast of generic task output to MyFormatX
wf.add(
specific_task(
in_file=wf.generic.lzout.out,
name="specific2",
in_file=wf.inner.lzout.out,
name="exit",
)
)

wf.add(
specific_task(
in_file=wf.generic.lzout.out.cast(MyFormatX),
name="specific2",
in_file=wf.inner.lzout.out.cast(MyFormatX),
name="exit",
)
)

wf.set_output(
[
("out_file", wf.specific2.lzout.out),
("out_file", wf.exit.lzout.out),
]
)

my_fspath = tmp_path / "in_file.my"
hdr_fspath = tmp_path / "in_file.hdr"
my_fspath.write_text("my-format")
hdr_fspath.write_text("my-header")
in_file = MyFormatX([my_fspath, hdr_fspath])
in_file = MyFormatX.sample()

result = wf(in_file=in_file, plugin="serial")

Expand All @@ -611,6 +709,63 @@ def test_type_is_subclass3():
assert TypeParser.is_subclass(ty.Type[Json], ty.Type[File])


def test_union_is_subclass1():
assert TypeParser.is_subclass(ty.Union[Json, Yaml], ty.Union[Json, Yaml, Xml])


def test_union_is_subclass2():
assert not TypeParser.is_subclass(ty.Union[Json, Yaml, Xml], ty.Union[Json, Yaml])


def test_union_is_subclass3():
assert TypeParser.is_subclass(Json, ty.Union[Json, Yaml])


def test_union_is_subclass4():
assert not TypeParser.is_subclass(ty.Union[Json, Yaml], Json)


def test_generic_is_subclass1():
assert TypeParser.is_subclass(ty.List[int], list)


def test_generic_is_subclass2():
assert not TypeParser.is_subclass(list, ty.List[int])


def test_generic_is_subclass3():
assert not TypeParser.is_subclass(ty.List[float], ty.List[int])


@pytest.mark.skipif(
sys.version_info < (3, 9), reason="Cannot subscript tuple in < Py3.9"
)
def test_generic_is_subclass4():
class MyTuple(tuple):
pass

class A:
pass

class B(A):
pass

assert TypeParser.is_subclass(MyTuple[A], ty.Tuple[A])
assert TypeParser.is_subclass(ty.Tuple[B], ty.Tuple[A])
assert TypeParser.is_subclass(MyTuple[B], ty.Tuple[A])
assert not TypeParser.is_subclass(ty.Tuple[A], ty.Tuple[B])
assert not TypeParser.is_subclass(ty.Tuple[A], MyTuple[A])
assert not TypeParser.is_subclass(MyTuple[A], ty.Tuple[B])
assert TypeParser.is_subclass(MyTuple[A, int], ty.Tuple[A, int])
assert TypeParser.is_subclass(ty.Tuple[B, int], ty.Tuple[A, int])
assert TypeParser.is_subclass(MyTuple[B, int], ty.Tuple[A, int])
assert TypeParser.is_subclass(MyTuple[int, B], ty.Tuple[int, A])
assert not TypeParser.is_subclass(MyTuple[B, int], ty.Tuple[int, A])
assert not TypeParser.is_subclass(MyTuple[int, B], ty.Tuple[A, int])
assert not TypeParser.is_subclass(MyTuple[B, int], ty.Tuple[A])
assert not TypeParser.is_subclass(MyTuple[B], ty.Tuple[A, int])


def test_type_is_instance1():
assert TypeParser.is_instance(File, ty.Type[File])

Expand Down
65 changes: 63 additions & 2 deletions pydra/utils/tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from fileformats.generic import File
from fileformats.core.mixin import WithSeparateHeader
from fileformats.core.mixin import WithSeparateHeader, WithMagicNumber
from pydra import mark
from pydra.engine.task import ShellCommandTask
from pydra.engine import specs


class MyFormat(File):
class MyFormat(WithMagicNumber, File):
ext = ".my"
magic_number = b"MYFORMAT"


class MyHeader(File):
Expand All @@ -17,6 +18,12 @@ class MyFormatX(WithSeparateHeader, MyFormat):
header_type = MyHeader


class MyOtherFormatX(WithMagicNumber, WithSeparateHeader, File):
magic_number = b"MYFORMAT"
ext = ".my"
header_type = MyHeader


@mark.task
def generic_func_task(in_file: File) -> File:
return in_file
Expand Down Expand Up @@ -118,3 +125,57 @@ class SpecificShellTask(ShellCommandTask):
input_spec = specific_shell_input_spec
output_spec = specific_shelloutput_spec
executable = "echo"


@mark.task
def other_specific_func_task(in_file: MyOtherFormatX) -> MyOtherFormatX:
return in_file


other_specific_shell_input_fields = [
(
"in_file",
MyOtherFormatX,
{
"help_string": "the input file",
"argstr": "",
"copyfile": "copy",
"sep": " ",
},
),
(
"out",
str,
{
"help_string": "output file name",
"argstr": "",
"position": -1,
"output_file_template": "{in_file}", # Pass through un-altered
},
),
]

other_specific_shell_input_spec = specs.SpecInfo(
name="Input", fields=other_specific_shell_input_fields, bases=(specs.ShellSpec,)
)

other_specific_shell_output_fields = [
(
"out",
MyOtherFormatX,
{
"help_string": "output file",
},
),
]
other_specific_shelloutput_spec = specs.SpecInfo(
name="Output",
fields=other_specific_shell_output_fields,
bases=(specs.ShellOutSpec,),
)


class OtherSpecificShellTask(ShellCommandTask):
input_spec = other_specific_shell_input_spec
output_spec = other_specific_shelloutput_spec
executable = "echo"
Loading

0 comments on commit 1858668

Please sign in to comment.