diff --git a/pydra/engine/helpers.py b/pydra/engine/helpers.py
index 5b411a9a9..4d8e84132 100644
--- a/pydra/engine/helpers.py
+++ b/pydra/engine/helpers.py
@@ -263,7 +263,9 @@ def make_klass(spec):
                 **kwargs,
             )
             checker_label = f"'{name}' field of {spec.name}"
-            type_checker = TypeParser[newfield.type](newfield.type, label=checker_label)
+            type_checker = TypeParser[newfield.type](
+                newfield.type, label=checker_label, superclass_auto_cast=True
+            )
             if newfield.type in (MultiInputObj, MultiInputFile):
                 converter = attr.converters.pipe(ensure_list, type_checker)
             elif newfield.type in (MultiOutputObj, MultiOutputFile):
diff --git a/pydra/engine/specs.py b/pydra/engine/specs.py
index 02f14a78d..16cd925ce 100644
--- a/pydra/engine/specs.py
+++ b/pydra/engine/specs.py
@@ -449,16 +449,19 @@ def collect_additional_outputs(self, inputs, output_dir, outputs):
                 ),
             ):
                 raise TypeError(
-                    f"Support for {fld.type} type, required for {fld.name} in {self}, "
+                    f"Support for {fld.type} type, required for '{fld.name}' in {self}, "
                     "has not been implemented in collect_additional_output"
                 )
             # assuming that field should have either default or metadata, but not both
             input_value = getattr(inputs, fld.name, attr.NOTHING)
             if input_value is not attr.NOTHING:
                 if TypeParser.contains_type(FileSet, fld.type):
-                    label = f"output field '{fld.name}' of {self}"
-                    input_value = TypeParser(fld.type, label=label).coerce(input_value)
-                    additional_out[fld.name] = input_value
+                    if input_value is not False:
+                        label = f"output field '{fld.name}' of {self}"
+                        input_value = TypeParser(fld.type, label=label).coerce(
+                            input_value
+                        )
+                        additional_out[fld.name] = input_value
             elif (
                 fld.default is None or fld.default == attr.NOTHING
             ) and not fld.metadata:  # TODO: is it right?
diff --git a/pydra/engine/submitter.py b/pydra/engine/submitter.py
index 6effed253..e2610c9bd 100644
--- a/pydra/engine/submitter.py
+++ b/pydra/engine/submitter.py
@@ -1,9 +1,10 @@
 """Handle execution backends."""
 import asyncio
+import typing as ty
 import pickle
 from uuid import uuid4
-from .workers import WORKERS
+from .workers import Worker, WORKERS
 from .core import is_workflow
 from .helpers import get_open_loop, load_and_run_async
@@ -16,24 +17,34 @@ class Submitter:
     """Send a task to the execution backend."""
 
-    def __init__(self, plugin="cf", **kwargs):
+    def __init__(self, plugin: ty.Union[str, ty.Type[Worker]] = "cf", **kwargs):
         """
         Initialize task submission.
 
         Parameters
         ----------
-        plugin : :obj:`str`
-            The identifier of the execution backend.
+        plugin : :obj:`str` or :obj:`ty.Type[pydra.engine.workers.Worker]`
+            Either the identifier of the execution backend or the worker class itself.
             Default is ``cf`` (Concurrent Futures).
+        **kwargs
+            Additional keyword arguments to pass to the worker.
""" self.loop = get_open_loop() self._own_loop = not self.loop.is_running() - self.plugin = plugin - try: - self.worker = WORKERS[self.plugin](**kwargs) - except KeyError: - raise NotImplementedError(f"No worker for {self.plugin}") + if isinstance(plugin, str): + self.plugin = plugin + try: + worker_cls = WORKERS[self.plugin] + except KeyError: + raise NotImplementedError(f"No worker for '{self.plugin}' plugin") + else: + try: + self.plugin = plugin.plugin_name + except AttributeError: + raise ValueError("Worker class must have a 'plugin_name' str attribute") + worker_cls = plugin + self.worker = worker_cls(**kwargs) self.worker.loop = self.loop def __call__(self, runnable, cache_locations=None, rerun=False, environment=None): diff --git a/pydra/engine/tests/test_node_task.py b/pydra/engine/tests/test_node_task.py index 4e182781b..37ed90d03 100644 --- a/pydra/engine/tests/test_node_task.py +++ b/pydra/engine/tests/test_node_task.py @@ -133,21 +133,7 @@ def test_task_init_3a( def test_task_init_4(): - """task with interface and inputs. splitter set using split method""" - nn = fun_addtwo(name="NA") - nn.split(splitter="a", a=[3, 5]) - assert np.allclose(nn.inputs.a, [3, 5]) - - assert nn.state.splitter == "NA.a" - assert nn.state.splitter_rpn == ["NA.a"] - - nn.state.prepare_states(nn.inputs) - assert nn.state.states_ind == [{"NA.a": 0}, {"NA.a": 1}] - assert nn.state.states_val == [{"NA.a": 3}, {"NA.a": 5}] - - -def test_task_init_4a(): - """task with a splitter and inputs set in the split method""" + """task with interface splitter and inputs set in the split method""" nn = fun_addtwo(name="NA") nn.split(splitter="a", a=[3, 5]) assert np.allclose(nn.inputs.a, [3, 5]) diff --git a/pydra/engine/tests/test_shelltask.py b/pydra/engine/tests/test_shelltask.py index a13bbc52c..4857db094 100644 --- a/pydra/engine/tests/test_shelltask.py +++ b/pydra/engine/tests/test_shelltask.py @@ -4347,6 +4347,102 @@ def change_name(file): # res = shelly(plugin="cf") +def test_shell_cmd_optional_output_file1(tmp_path): + """ + Test to see that 'unused' doesn't complain about not having an output passed to it + """ + my_cp_spec = SpecInfo( + name="Input", + fields=[ + ( + "input", + attr.ib( + type=File, metadata={"argstr": "", "help_string": "input file"} + ), + ), + ( + "output", + attr.ib( + type=Path, + metadata={ + "argstr": "", + "output_file_template": "out.txt", + "help_string": "output file", + }, + ), + ), + ( + "unused", + attr.ib( + type=ty.Union[Path, bool], + default=False, + metadata={ + "argstr": "--not-used", + "output_file_template": "out.txt", + "help_string": "dummy output", + }, + ), + ), + ], + bases=(ShellSpec,), + ) + + my_cp = ShellCommandTask( + name="my_cp", + executable="cp", + input_spec=my_cp_spec, + ) + file1 = tmp_path / "file1.txt" + file1.write_text("foo") + result = my_cp(input=file1, unused=False) + assert result.output.output.fspath.read_text() == "foo" + + +def test_shell_cmd_optional_output_file2(tmp_path): + """ + Test to see that 'unused' doesn't complain about not having an output passed to it + """ + my_cp_spec = SpecInfo( + name="Input", + fields=[ + ( + "input", + attr.ib( + type=File, metadata={"argstr": "", "help_string": "input file"} + ), + ), + ( + "output", + attr.ib( + type=ty.Union[Path, bool], + default=False, + metadata={ + "argstr": "", + "output_file_template": "out.txt", + "help_string": "dummy output", + }, + ), + ), + ], + bases=(ShellSpec,), + ) + + my_cp = ShellCommandTask( + name="my_cp", + executable="cp", + input_spec=my_cp_spec, + ) + file1 = 
tmp_path / "file1.txt" + file1.write_text("foo") + result = my_cp(input=file1, output=True) + assert result.output.output.fspath.read_text() == "foo" + + file2 = tmp_path / "file2.txt" + file2.write_text("bar") + with pytest.raises(RuntimeError): + my_cp(input=file2, output=False) + + def test_shell_cmd_non_existing_outputs_1(tmp_path): """Checking that non existing output files do not return a phantom path, but return NOTHING instead""" diff --git a/pydra/engine/tests/test_submitter.py b/pydra/engine/tests/test_submitter.py index 761c11d88..7098f6688 100644 --- a/pydra/engine/tests/test_submitter.py +++ b/pydra/engine/tests/test_submitter.py @@ -6,6 +6,8 @@ import attrs import typing as ty from random import randint +import os +from unittest.mock import patch import pytest from fileformats.generic import Directory from .utils import ( @@ -15,8 +17,9 @@ gen_basic_wf_with_threadcount, gen_basic_wf_with_threadcount_concurrent, ) -from ..core import Workflow +from ..core import Workflow, TaskBase from ..submitter import Submitter +from ..workers import SerialWorker from ... import mark from pathlib import Path from datetime import datetime @@ -665,3 +668,65 @@ def to_tuple(x, y): ): with Submitter("cf") as sub: result = sub(wf) + +@mark.task +def to_tuple(x, y): + return (x, y) + + +class BYOAddVarWorker(SerialWorker): + """A dummy worker that adds 1 to the output of the task""" + + plugin_name = "byo_add_env_var" + + def __init__(self, add_var, **kwargs): + super().__init__(**kwargs) + self.add_var = add_var + + async def exec_serial(self, runnable, rerun=False, environment=None): + if isinstance(runnable, TaskBase): + with patch.dict(os.environ, {"BYO_ADD_VAR": str(self.add_var)}): + result = runnable._run(rerun, environment=environment) + return result + else: # it could be tuple that includes pickle files with tasks and inputs + return super().exec_serial(runnable, rerun, environment) + + +@mark.task +def add_env_var_task(x: int) -> int: + return x + int(os.environ.get("BYO_ADD_VAR", 0)) + + +def test_byo_worker(): + + task1 = add_env_var_task(x=1) + + with Submitter(plugin=BYOAddVarWorker, add_var=10) as sub: + assert sub.plugin == "byo_add_env_var" + result = task1(submitter=sub) + + assert result.output.out == 11 + + task2 = add_env_var_task(x=2) + + with Submitter(plugin="serial") as sub: + result = task2(submitter=sub) + + assert result.output.out == 2 + + +def test_bad_builtin_worker(): + + with pytest.raises(NotImplementedError, match="No worker for 'bad-worker' plugin"): + Submitter(plugin="bad-worker") + + +def test_bad_byo_worker(): + + class BadWorker: + pass + + with pytest.raises( + ValueError, match="Worker class must have a 'plugin_name' str attribute" + ): + Submitter(plugin=BadWorker) diff --git a/pydra/engine/workers.py b/pydra/engine/workers.py index 155a2800d..eaa40beb0 100644 --- a/pydra/engine/workers.py +++ b/pydra/engine/workers.py @@ -128,6 +128,8 @@ async def fetch_finished(self, futures): class SerialWorker(Worker): """A worker to execute linearly.""" + plugin_name = "serial" + def __init__(self, **kwargs): """Initialize worker.""" logger.debug("Initialize SerialWorker") @@ -157,6 +159,8 @@ async def fetch_finished(self, futures): class ConcurrentFuturesWorker(Worker): """A worker to execute in parallel using Python's concurrent futures.""" + plugin_name = "cf" + def __init__(self, n_procs=None): """Initialize Worker.""" super().__init__() @@ -192,6 +196,7 @@ def close(self): class SlurmWorker(DistributedWorker): """A worker to execute tasks on SLURM systems.""" 
+    plugin_name = "slurm"
     _cmd = "sbatch"
     _sacct_re = re.compile(
         "(?P<jobid>\\d*) +(?P<status>\\w*)\\+? +" "(?P<exit_code>\\d+):\\d+"
     )
@@ -367,6 +372,8 @@ async def _verify_exit_code(self, jobid):
 class SGEWorker(DistributedWorker):
     """A worker to execute tasks on SGE systems."""
 
+    plugin_name = "sge"
+
     _cmd = "qsub"
     _sacct_re = re.compile(
         "(?P<jobid>\\d*) +(?P<status>\\w*)\\+? +" "(?P<exit_code>\\d+):\\d+"
     )
@@ -860,6 +867,8 @@ class DaskWorker(Worker):
     This is an experimental implementation with limited testing.
     """
 
+    plugin_name = "dask"
+
     def __init__(self, **kwargs):
         """Initialize Worker."""
         super().__init__()
@@ -898,7 +907,7 @@ def close(self):
 class PsijWorker(Worker):
     """A worker to execute tasks using PSI/J."""
 
-    def __init__(self, subtype, **kwargs):
+    def __init__(self, **kwargs):
         """
         Initialize PsijWorker.
 
@@ -915,15 +924,6 @@ def __init__(self, subtype, **kwargs):
         logger.debug("Initialize PsijWorker")
         self.psij = psij
 
-        # Check if the provided subtype is valid
-        valid_subtypes = ["local", "slurm"]
-        if subtype not in valid_subtypes:
-            raise ValueError(
-                f"Invalid 'subtype' provided. Available options: {', '.join(valid_subtypes)}"
-            )
-
-        self.subtype = subtype
-
     def run_el(self, interface, rerun=False, **kwargs):
         """Run a task."""
         return self.exec_psij(interface, rerun=rerun)
@@ -1039,14 +1039,29 @@ def close(self):
         pass
 
 
+class PsijLocalWorker(PsijWorker):
+    """A worker to execute tasks using PSI/J on the local machine."""
+
+    subtype = "local"
+    plugin_name = f"psij-{subtype}"
+
+
+class PsijSlurmWorker(PsijWorker):
+    """A worker to execute tasks using PSI/J via SLURM."""
+
+    subtype = "slurm"
+    plugin_name = f"psij-{subtype}"
+
+
 WORKERS = {
-    "serial": SerialWorker,
-    "cf": ConcurrentFuturesWorker,
-    "slurm": SlurmWorker,
-    "dask": DaskWorker,
-    "sge": SGEWorker,
-    **{
-        "psij-" + subtype: lambda subtype=subtype: PsijWorker(subtype=subtype)
-        for subtype in ["local", "slurm"]
-    },
+    w.plugin_name: w
+    for w in (
+        SerialWorker,
+        ConcurrentFuturesWorker,
+        SlurmWorker,
+        DaskWorker,
+        SGEWorker,
+        PsijLocalWorker,
+        PsijSlurmWorker,
+    )
 }
diff --git a/pydra/utils/tests/test_typing.py b/pydra/utils/tests/test_typing.py
index f88aeafe1..665d79327 100644
--- a/pydra/utils/tests/test_typing.py
+++ b/pydra/utils/tests/test_typing.py
@@ -1,5 +1,6 @@
 import os
 import itertools
+import sys
 import typing as ty
 from pathlib import Path
 import tempfile
@@ -8,13 +9,16 @@
 from ...engine.specs import File, LazyOutField
 from ..typing import TypeParser
 from pydra import Workflow
-from fileformats.application import Json
+from fileformats.application import Json, Yaml, Xml
 from .utils import (
     generic_func_task,
     GenericShellTask,
     specific_func_task,
     SpecificShellTask,
+    other_specific_func_task,
+    OtherSpecificShellTask,
     MyFormatX,
+    MyOtherFormatX,
     MyHeader,
 )
@@ -152,8 +156,12 @@ def test_type_check_nested6():
 
 
 def test_type_check_nested7():
+    TypeParser(ty.Tuple[float, float, float])(lz(ty.List[int]))
+
+
+def test_type_check_nested7a():
     with pytest.raises(TypeError, match="Wrong number of type arguments"):
-        TypeParser(ty.Tuple[float, float, float])(lz(ty.List[int]))
+        TypeParser(ty.Tuple[float, float, float])(lz(ty.Tuple[int]))
 
 
 def test_type_check_nested8():
@@ -164,6 +172,18 @@ def test_type_check_nested8():
     )(lz(ty.List[float]))
 
 
+def test_type_check_permit_superclass():
+    # Typical case as Json is a subclass of File
+    TypeParser(ty.List[File])(lz(ty.List[Json]))
+    # Permissive superclass check, as File is a superclass of Json
+    TypeParser(ty.List[Json], superclass_auto_cast=True)(lz(ty.List[File]))
+    with pytest.raises(TypeError, match="Cannot coerce"):
match="Cannot coerce"): + TypeParser(ty.List[Json], superclass_auto_cast=False)(lz(ty.List[File])) + # Fails because Yaml is neither sub or super class of Json + with pytest.raises(TypeError, match="Cannot coerce"): + TypeParser(ty.List[Json], superclass_auto_cast=True)(lz(ty.List[Yaml])) + + def test_type_check_fail1(): with pytest.raises(TypeError, match="Wrong number of type arguments in tuple"): TypeParser(ty.Tuple[int, int, int])(lz(ty.Tuple[float, float, float, float])) @@ -490,14 +510,29 @@ def test_matches_type_tuple(): assert not TypeParser.matches_type(ty.Tuple[int], ty.Tuple[int, int]) -def test_matches_type_tuple_ellipsis(): +def test_matches_type_tuple_ellipsis1(): assert TypeParser.matches_type(ty.Tuple[int], ty.Tuple[int, ...]) + + +def test_matches_type_tuple_ellipsis2(): assert TypeParser.matches_type(ty.Tuple[int, int], ty.Tuple[int, ...]) + + +def test_matches_type_tuple_ellipsis3(): assert not TypeParser.matches_type(ty.Tuple[int, float], ty.Tuple[int, ...]) - assert not TypeParser.matches_type(ty.Tuple[int, ...], ty.Tuple[int]) + + +def test_matches_type_tuple_ellipsis4(): + assert TypeParser.matches_type(ty.Tuple[int, ...], ty.Tuple[int]) + + +def test_matches_type_tuple_ellipsis5(): assert TypeParser.matches_type( ty.Tuple[int], ty.List[int], coercible=[(tuple, list)] ) + + +def test_matches_type_tuple_ellipsis6(): assert TypeParser.matches_type( ty.Tuple[int, ...], ty.List[int], coercible=[(tuple, list)] ) @@ -538,7 +573,17 @@ def specific_task(request): assert False -def test_typing_cast(tmp_path, generic_task, specific_task): +@pytest.fixture(params=["func", "shell"]) +def other_specific_task(request): + if request.param == "func": + return other_specific_func_task + elif request.param == "shell": + return OtherSpecificShellTask + else: + assert False + + +def test_typing_implicit_cast_from_super(tmp_path, generic_task, specific_task): """Check the casting of lazy fields and whether specific file-sets can be recovered from generic `File` classes""" @@ -562,33 +607,86 @@ def test_typing_cast(tmp_path, generic_task, specific_task): ) ) + wf.add( + specific_task( + in_file=wf.generic.lzout.out, + name="specific2", + ) + ) + + wf.set_output( + [ + ("out_file", wf.specific2.lzout.out), + ] + ) + + in_file = MyFormatX.sample() + + result = wf(in_file=in_file, plugin="serial") + + out_file: MyFormatX = result.output.out_file + assert type(out_file) is MyFormatX + assert out_file.parent != in_file.parent + assert type(out_file.header) is MyHeader + assert out_file.header.parent != in_file.header.parent + + +def test_typing_cast(tmp_path, specific_task, other_specific_task): + """Check the casting of lazy fields and whether specific file-sets can be recovered + from generic `File` classes""" + + wf = Workflow( + name="test", + input_spec={"in_file": MyFormatX}, + output_spec={"out_file": MyFormatX}, + ) + + wf.add( + specific_task( + in_file=wf.lzin.in_file, + name="entry", + ) + ) + + with pytest.raises(TypeError, match="Cannot coerce"): + # No cast of generic task output to MyFormatX + wf.add( # Generic task + other_specific_task( + in_file=wf.entry.lzout.out, + name="inner", + ) + ) + + wf.add( # Generic task + other_specific_task( + in_file=wf.entry.lzout.out.cast(MyOtherFormatX), + name="inner", + ) + ) + with pytest.raises(TypeError, match="Cannot coerce"): # No cast of generic task output to MyFormatX wf.add( specific_task( - in_file=wf.generic.lzout.out, - name="specific2", + in_file=wf.inner.lzout.out, + name="exit", ) ) wf.add( specific_task( - 
-            in_file=wf.generic.lzout.out.cast(MyFormatX),
-            name="specific2",
+            in_file=wf.inner.lzout.out.cast(MyFormatX),
+            name="exit",
         )
     )
 
     wf.set_output(
         [
-            ("out_file", wf.specific2.lzout.out),
+            ("out_file", wf.exit.lzout.out),
         ]
     )
 
-    my_fspath = tmp_path / "in_file.my"
-    hdr_fspath = tmp_path / "in_file.hdr"
-    my_fspath.write_text("my-format")
-    hdr_fspath.write_text("my-header")
-    in_file = MyFormatX([my_fspath, hdr_fspath])
+    in_file = MyFormatX.sample()
 
     result = wf(in_file=in_file, plugin="serial")
 
@@ -611,6 +709,63 @@ def test_type_is_subclass3():
     assert TypeParser.is_subclass(ty.Type[Json], ty.Type[File])
 
 
+def test_union_is_subclass1():
+    assert TypeParser.is_subclass(ty.Union[Json, Yaml], ty.Union[Json, Yaml, Xml])
+
+
+def test_union_is_subclass2():
+    assert not TypeParser.is_subclass(ty.Union[Json, Yaml, Xml], ty.Union[Json, Yaml])
+
+
+def test_union_is_subclass3():
+    assert TypeParser.is_subclass(Json, ty.Union[Json, Yaml])
+
+
+def test_union_is_subclass4():
+    assert not TypeParser.is_subclass(ty.Union[Json, Yaml], Json)
+
+
+def test_generic_is_subclass1():
+    assert TypeParser.is_subclass(ty.List[int], list)
+
+
+def test_generic_is_subclass2():
+    assert not TypeParser.is_subclass(list, ty.List[int])
+
+
+def test_generic_is_subclass3():
+    assert not TypeParser.is_subclass(ty.List[float], ty.List[int])
+
+
+@pytest.mark.skipif(
+    sys.version_info < (3, 9), reason="Cannot subscript tuple in < Py3.9"
+)
+def test_generic_is_subclass4():
+    class MyTuple(tuple):
+        pass
+
+    class A:
+        pass
+
+    class B(A):
+        pass
+
+    assert TypeParser.is_subclass(MyTuple[A], ty.Tuple[A])
+    assert TypeParser.is_subclass(ty.Tuple[B], ty.Tuple[A])
+    assert TypeParser.is_subclass(MyTuple[B], ty.Tuple[A])
+    assert not TypeParser.is_subclass(ty.Tuple[A], ty.Tuple[B])
+    assert not TypeParser.is_subclass(ty.Tuple[A], MyTuple[A])
+    assert not TypeParser.is_subclass(MyTuple[A], ty.Tuple[B])
+    assert TypeParser.is_subclass(MyTuple[A, int], ty.Tuple[A, int])
+    assert TypeParser.is_subclass(ty.Tuple[B, int], ty.Tuple[A, int])
+    assert TypeParser.is_subclass(MyTuple[B, int], ty.Tuple[A, int])
+    assert TypeParser.is_subclass(MyTuple[int, B], ty.Tuple[int, A])
+    assert not TypeParser.is_subclass(MyTuple[B, int], ty.Tuple[int, A])
+    assert not TypeParser.is_subclass(MyTuple[int, B], ty.Tuple[A, int])
+    assert not TypeParser.is_subclass(MyTuple[B, int], ty.Tuple[A])
+    assert not TypeParser.is_subclass(MyTuple[B], ty.Tuple[A, int])
+
+
 def test_type_is_instance1():
     assert TypeParser.is_instance(File, ty.Type[File])
diff --git a/pydra/utils/tests/utils.py b/pydra/utils/tests/utils.py
index eb452edf9..3582fa9ed 100644
--- a/pydra/utils/tests/utils.py
+++ b/pydra/utils/tests/utils.py
@@ -1,12 +1,13 @@
 from fileformats.generic import File
-from fileformats.core.mixin import WithSeparateHeader
+from fileformats.core.mixin import WithSeparateHeader, WithMagicNumber
 from pydra import mark
 from pydra.engine.task import ShellCommandTask
 from pydra.engine import specs
 
 
-class MyFormat(File):
+class MyFormat(WithMagicNumber, File):
     ext = ".my"
+    magic_number = b"MYFORMAT"
 
 
 class MyHeader(File):
@@ -17,6 +18,12 @@ class MyFormatX(WithSeparateHeader, MyFormat):
     header_type = MyHeader
 
 
+class MyOtherFormatX(WithMagicNumber, WithSeparateHeader, File):
+    magic_number = b"MYFORMAT"
+    ext = ".my"
+    header_type = MyHeader
+
+
 @mark.task
 def generic_func_task(in_file: File) -> File:
     return in_file
@@ -118,3 +125,57 @@ class SpecificShellTask(ShellCommandTask):
     input_spec = specific_shell_input_spec
     output_spec = specific_shelloutput_spec
executable = "echo" + + +@mark.task +def other_specific_func_task(in_file: MyOtherFormatX) -> MyOtherFormatX: + return in_file + + +other_specific_shell_input_fields = [ + ( + "in_file", + MyOtherFormatX, + { + "help_string": "the input file", + "argstr": "", + "copyfile": "copy", + "sep": " ", + }, + ), + ( + "out", + str, + { + "help_string": "output file name", + "argstr": "", + "position": -1, + "output_file_template": "{in_file}", # Pass through un-altered + }, + ), +] + +other_specific_shell_input_spec = specs.SpecInfo( + name="Input", fields=other_specific_shell_input_fields, bases=(specs.ShellSpec,) +) + +other_specific_shell_output_fields = [ + ( + "out", + MyOtherFormatX, + { + "help_string": "output file", + }, + ), +] +other_specific_shelloutput_spec = specs.SpecInfo( + name="Output", + fields=other_specific_shell_output_fields, + bases=(specs.ShellOutSpec,), +) + + +class OtherSpecificShellTask(ShellCommandTask): + input_spec = other_specific_shell_input_spec + output_spec = other_specific_shelloutput_spec + executable = "echo" diff --git a/pydra/utils/typing.py b/pydra/utils/typing.py index ceddc7e21..ee8e733e4 100644 --- a/pydra/utils/typing.py +++ b/pydra/utils/typing.py @@ -4,6 +4,7 @@ import os import sys import typing as ty +import logging import attr from ..engine.specs import ( LazyField, @@ -19,6 +20,7 @@ # Python < 3.8 from typing_extensions import get_origin, get_args # type: ignore +logger = logging.getLogger("pydra") NO_GENERIC_ISSUBCLASS = sys.version_info.major == 3 and sys.version_info.minor < 10 @@ -56,6 +58,9 @@ class TypeParser(ty.Generic[T]): the tree of more complex nested container types. Overrides 'coercible' to enable you to carve out exceptions, such as TypeParser(list, coercible=[(ty.Iterable, list)], not_coercible=[(str, list)]) + superclass_auto_cast : bool + Allow lazy fields to pass the type check if their types are superclasses of the + specified pattern (instead of matching or being subclasses of the pattern) label : str the label to be used to identify the type parser in error messages. Especially useful when TypeParser is used as a converter in attrs.fields @@ -64,6 +69,7 @@ class TypeParser(ty.Generic[T]): tp: ty.Type[T] coercible: ty.List[ty.Tuple[TypeOrAny, TypeOrAny]] not_coercible: ty.List[ty.Tuple[TypeOrAny, TypeOrAny]] + superclass_auto_cast: bool label: str COERCIBLE_DEFAULT: ty.Tuple[ty.Tuple[type, type], ...] = ( @@ -107,6 +113,7 @@ def __init__( not_coercible: ty.Optional[ ty.Iterable[ty.Tuple[TypeOrAny, TypeOrAny]] ] = NOT_COERCIBLE_DEFAULT, + superclass_auto_cast: bool = False, label: str = "", ): def expand_pattern(t): @@ -135,6 +142,7 @@ def expand_pattern(t): ) self.not_coercible = list(not_coercible) if not_coercible is not None else [] self.pattern = expand_pattern(tp) + self.superclass_auto_cast = superclass_auto_cast def __call__(self, obj: ty.Any) -> ty.Union[T, LazyField[T]]: """Attempts to coerce the object to the specified type, unless the value is @@ -161,7 +169,27 @@ def __call__(self, obj: ty.Any) -> ty.Union[T, LazyField[T]]: if obj is attr.NOTHING: coerced = attr.NOTHING # type: ignore[assignment] elif isinstance(obj, LazyField): - self.check_type(obj.type) + try: + self.check_type(obj.type) + except TypeError as e: + if self.superclass_auto_cast: + try: + # Check whether the type of the lazy field isn't a superclass of + # the type to check against, and if so, allow it due to permissive + # typing rules. 
+                        TypeParser(obj.type).check_type(self.tp)
+                    except TypeError:
+                        raise e
+                    else:
+                        logger.info(
+                            "Connecting lazy field %s to %s%s via permissive typing that "
+                            "allows super-to-sub type connections",
+                            obj,
+                            self.tp,
+                            self.label_str,
+                        )
+                else:
+                    raise e
             coerced = obj  # type: ignore
         elif isinstance(obj, StateArray):
             coerced = StateArray(self(o) for o in obj)  # type: ignore[assignment]
@@ -421,6 +449,10 @@ def check_tuple(tp_args, pattern_args):
             for arg in tp_args:
                 expand_and_check(arg, pattern_args[0])
             return
+        elif tp_args[-1] is Ellipsis:
+            for pattern_arg in pattern_args:
+                expand_and_check(tp_args[0], pattern_arg)
+            return
         if len(tp_args) != len(pattern_args):
             raise TypeError(
                 f"Wrong number of type arguments in tuple {tp_args} compared to pattern "
@@ -464,8 +496,17 @@ def check_coercible(
         explicit inclusions and exclusions set in the `coercible` and `not_coercible`
         member attrs
         """
+        # Short-circuit the basic case where the source and target are the same
         if source is target:
             return
+        if self.superclass_auto_cast and self.is_subclass(target, type(source)):
+            logger.info(
+                "Attempting to coerce %s into %s due to super-to-sub class coercion "
+                "being permitted",
+                source,
+                target,
+            )
+            return
         source_origin = get_origin(source)
         if source_origin is not None:
             source = source_origin
@@ -562,7 +603,7 @@ def matches_type(
     def is_instance(
         cls,
         obj: object,
-        candidates: ty.Union[ty.Type[ty.Any], ty.Iterable[ty.Type[ty.Any]]],
+        candidates: ty.Union[ty.Type[ty.Any], ty.Sequence[ty.Type[ty.Any]]],
     ) -> bool:
         """Checks whether the object is an instance of cls or that cls is typing.Any,
         extending the built-in isinstance to check nested type args
@@ -574,7 +615,7 @@ def is_instance(
         candidates : type or ty.Sequence[type]
             the candidate types to check the object against
         """
-        if not isinstance(candidates, (tuple, list)):
+        if not isinstance(candidates, ty.Sequence):
             candidates = [candidates]
         for candidate in candidates:
             if candidate is ty.Any:
                 return True
@@ -600,7 +641,7 @@ def is_instance(
     def is_subclass(
         cls,
         klass: ty.Type[ty.Any],
-        candidates: ty.Union[ty.Type[ty.Any], ty.Iterable[ty.Type[ty.Any]]],
+        candidates: ty.Union[ty.Type[ty.Any], ty.Sequence[ty.Type[ty.Any]]],
         any_ok: bool = False,
     ) -> bool:
         """Checks whether the class a is either the same as b, a subclass of b or b is
@@ -617,16 +658,23 @@ def is_subclass(
         """
         if not isinstance(candidates, ty.Sequence):
             candidates = [candidates]
+        if ty.Any in candidates:
+            return True
+        if klass is ty.Any:
+            return any_ok
+
+        origin = get_origin(klass)
+        args = get_args(klass)
 
         for candidate in candidates:
+            candidate_origin = get_origin(candidate)
+            candidate_args = get_args(candidate)
             # Handle ty.Type[*] types in klass and candidates
-            if ty.get_origin(klass) is type and (
-                candidate is type or ty.get_origin(candidate) is type
-            ):
+            if origin is type and (candidate is type or candidate_origin is type):
                 if candidate is type:
                     return True
-                return cls.is_subclass(ty.get_args(klass)[0], ty.get_args(candidate)[0])
-            elif ty.get_origin(klass) is type or ty.get_origin(candidate) is type:
+                return cls.is_subclass(args[0], candidate_args[0])
+            elif origin is type or candidate_origin is type:
                 return False
             if NO_GENERIC_ISSUBCLASS:
                 if klass is type and candidate is not type:
                     return False
                 if issubtype(klass, candidate) or (
                     klass is dict and candidate is ty.Mapping
                 ):
                     return True
             else:
-                if klass is ty.Any:
-                    if ty.Any in candidates:  # type: ignore
-                        return True
-                    else:
-                        return any_ok
-                origin = get_origin(klass)
                 if origin is ty.Union:
-                    args = get_args(klass)
-                    if get_origin(candidate) is ty.Union:
-                        candidate_args = get_args(candidate)
-                    else:
-                        candidate_args = [candidate]
-                    return all(
-                        any(cls.is_subclass(a, c) for a in args) for c in candidate_args
-                    )
-            if origin is not None:
-                klass = origin
-            if klass is candidate or candidate is ty.Any:
-                return True
-            if issubclass(klass, candidate):
-                return True
+                    union_args = (
+                        candidate_args if candidate_origin is ty.Union else (candidate,)
+                    )
+                    matches = all(
+                        any(cls.is_subclass(a, c) for c in union_args) for a in args
+                    )
+                    if matches:
+                        return True
+                else:
+                    if candidate_args and candidate_origin is not ty.Union:
+                        if (
+                            origin
+                            and issubclass(origin, candidate_origin)  # type: ignore[arg-type]
+                            and len(args) == len(candidate_args)
+                            and all(
+                                issubclass(a, c) for a, c in zip(args, candidate_args)
+                            )
+                        ):
+                            return True
+                    else:
+                        if issubclass(origin if origin else klass, candidate):
+                            return True
         return False
 
     @classmethod
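
Reviewer note, not part of the patch: a minimal sketch of the BYO-worker API added in `submitter.py`/`workers.py` above. The class name, `plugin_name` value, and `my_option` argument are hypothetical; the pattern mirrors `BYOAddVarWorker` in `test_submitter.py`. Any `Worker` subclass that defines a `plugin_name` str attribute can be passed directly to `Submitter`, which forwards extra keyword arguments to the worker's constructor.

```python
from pydra.engine.submitter import Submitter
from pydra.engine.workers import SerialWorker


class MyCustomWorker(SerialWorker):  # hypothetical example worker
    plugin_name = "my_custom"  # hypothetical plugin identifier

    def __init__(self, my_option=None, **kwargs):  # my_option is illustrative
        super().__init__(**kwargs)
        self.my_option = my_option  # forwarded from Submitter(**kwargs)


# The worker class itself is passed as the plugin; Submitter reads its
# plugin_name attribute and uses it as the plugin identifier
with Submitter(plugin=MyCustomWorker, my_option="value") as sub:
    assert sub.plugin == "my_custom"
```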
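Similarly, a sketch of the new `superclass_auto_cast` flag on `TypeParser`, mirroring `test_type_check_permit_superclass` above; `lz` is the helper defined in `test_typing.py` that wraps a type in a lazy field, so this runs in that module's context rather than standalone:

```python
import typing as ty
from fileformats.generic import File
from fileformats.application import Json
from pydra.utils.typing import TypeParser

# Json is a subclass of File, so this passes with or without the flag
TypeParser(ty.List[File])(lz(ty.List[Json]))
# File is only a *superclass* of Json: rejected by default ("Cannot coerce"),
# but permitted, and logged at INFO level, when superclass_auto_cast=True
TypeParser(ty.List[Json], superclass_auto_cast=True)(lz(ty.List[File]))
```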