Skip to content

Commit

Permalink
Added new method ArgumentParser.add_instantiator that enables develop…
Browse files Browse the repository at this point in the history
…ers to implement custom instantiation (#170).
  • Loading branch information
mauvilsa committed Aug 23, 2023
1 parent 5798fe3 commit fdbfe3b
Show file tree
Hide file tree
Showing 10 changed files with 212 additions and 16 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ Added
<https://github.com/omni-us/jsonargparse/issues/337#issuecomment-1665055459>`__).
- Improved resolving of nested forward references in types.
- The ``ext_vars`` for an ``ActionJsonnet`` argument can now have a default.
- New method ``ArgumentParser.add_instantiator`` that enables developers to
implement custom instantiation (`#170
<https://github.com/omni-us/jsonargparse/issues/170>`__).

Deprecated
^^^^^^^^^^
Expand Down
5 changes: 3 additions & 2 deletions jsonargparse/_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from contextvars import ContextVar
from typing import Any, Dict, List, Optional, Tuple, Type, Union

from ._common import is_subclass, parser_context
from ._common import get_class_instantiator, is_subclass, parser_context
from ._loaders_dumpers import get_loader_exceptions, load_value
from ._namespace import Namespace, split_key, split_key_root
from ._optionals import FilesCompleterMethod, get_config_read_mode
Expand Down Expand Up @@ -324,7 +324,8 @@ def check_type(self, value, parser):
return self._load_config(value, parser)

def instantiate_classes(self, value):
return self.basetype(**value)
instantiator_fn = get_class_instantiator()
return instantiator_fn(self.basetype, **value)


class _ActionHelpClassPath(Action):
Expand Down
44 changes: 43 additions & 1 deletion jsonargparse/_common.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,36 @@
import dataclasses
import inspect
import sys
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Optional, Union
from typing import Dict, Optional, Tuple, Type, TypeVar, Union

from ._namespace import Namespace
from ._type_checking import ArgumentParser

ClassType = TypeVar("ClassType")

if sys.version_info < (3, 8):
from typing import Callable

InstantiatorCallable = Callable[..., ClassType]
else:
from typing import Protocol

class InstantiatorCallable(Protocol):
def __call__(self, class_type: Type[ClassType], *args, **kwargs) -> ClassType:
pass # pragma: no cover


InstantiatorsDictType = Dict[Tuple[type, bool], InstantiatorCallable]


parent_parser: ContextVar["ArgumentParser"] = ContextVar("parent_parser")
parser_capture: ContextVar[bool] = ContextVar("parser_capture", default=False)
defaults_cache: ContextVar[Optional[Namespace]] = ContextVar("defaults_cache", default=None)
lenient_check: ContextVar[Union[bool, str]] = ContextVar("lenient_check", default=False)
load_value_mode: ContextVar[Optional[str]] = ContextVar("load_value_mode", default=None)
class_instantiators: ContextVar[Optional[InstantiatorsDictType]] = ContextVar("class_instantiators")


parser_context_vars = dict(
Expand All @@ -20,6 +39,7 @@
defaults_cache=defaults_cache,
lenient_check=lenient_check,
load_value_mode=load_value_mode,
class_instantiators=class_instantiators,
)


Expand Down Expand Up @@ -70,3 +90,25 @@ def is_dataclass_like(cls) -> bool:
if attrs.has(cls):
return True
return all_dataclasses


def default_class_instantiator(class_type: Type[ClassType], *args, **kwargs) -> ClassType:
return class_type(*args, **kwargs)


class ClassInstantiator:
def __init__(self, instantiators: InstantiatorsDictType) -> None:
self.instantiators = instantiators

def __call__(self, class_type: Type[ClassType], *args, **kwargs) -> ClassType:
for (cls, subclasses), instantiator in self.instantiators.items():
if class_type is cls or (subclasses and is_subclass(class_type, cls)):
return instantiator(class_type, *args, **kwargs)
return default_class_instantiator(class_type, *args, **kwargs)


def get_class_instantiator() -> InstantiatorCallable:
instantiators = class_instantiators.get()
if not instantiators:
return default_class_instantiator
return ClassInstantiator(instantiators)
53 changes: 50 additions & 3 deletions jsonargparse/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,13 @@
parent_parsers,
previous_config,
)
from ._common import is_dataclass_like, lenient_check, parser_context
from ._common import (
InstantiatorCallable,
InstantiatorsDictType,
is_dataclass_like,
lenient_check,
parser_context,
)
from ._deprecated import ParserDeprecations
from ._formatters import DefaultHelpFormatter, empty_help, get_env_var
from ._jsonnet import ActionJsonnet
Expand Down Expand Up @@ -73,6 +79,7 @@
from ._signatures import SignatureArguments
from ._typehints import ActionTypeHint, is_subclass_spec
from ._util import (
ClassType,
Path,
argument_error,
change_to_path_dir,
Expand Down Expand Up @@ -176,6 +183,7 @@ class ArgumentParser(ParserDeprecations, ActionsContainer, ArgumentLinking, argp
formatter_class: Type[DefaultHelpFormatter]
groups: Optional[Dict[str, "_ArgumentGroup"]] = None
_subcommands_action: Optional[_ActionSubCommands] = None
_instantiators: Optional[InstantiatorsDictType] = None

def __init__(
self,
Expand Down Expand Up @@ -1073,6 +1081,45 @@ def check_values(cfg):
message = prefix + message
raise type(ex)(message) from ex

def add_instantiator(
self,
instantiator: InstantiatorCallable,
class_type: Type[ClassType],
subclasses: bool = True,
prepend: bool = False,
) -> None:
"""Adds a custom instantiator for a class type. Used by ``instantiate_classes``.
Instantiator functions are expected to have as signature ``(class_type:
Type[ClassType], *args, **kwargs) -> ClassType``.
For reference, the default instantiator is ``return class_type(*args,
**kwargs)``.
Args:
instantiator: Function that instantiates a class.
class_type: The class type to instantiate.
subclasses: Whether to instantiate subclasses of ``class_type``.
prepend: Whether to prepend the instantiator to the existing instantiators.
"""
if self._instantiators is None:
self._instantiators = {}
key = (class_type, subclasses)
instantiators = {k: v for k, v in self._instantiators.items() if k != key}
if prepend:
self._instantiators = {key: instantiator, **instantiators}
else:
instantiators[key] = instantiator
self._instantiators = instantiators

def _get_instantiators(self):
instantiators = self._instantiators or {}
if hasattr(self, "parent_parser"):
parent_instantiators = self.parent_parser._get_instantiators()
instantiators = instantiators.copy()
instantiators.update({k: v for k, v in parent_instantiators.items() if k not in instantiators})
return instantiators

def instantiate_classes(
self,
cfg: Namespace,
Expand Down Expand Up @@ -1113,10 +1160,10 @@ def instantiate_classes(
pass
else:
if value is not None:
with parser_context(parent_parser=self):
with parser_context(parent_parser=self, class_instantiators=self._get_instantiators()):
parent[key] = component.instantiate_classes(value)
else:
with parser_context(load_value_mode=self.parser_mode):
with parser_context(load_value_mode=self.parser_mode, class_instantiators=self._get_instantiators()):
component.instantiate_class(component, cfg)

ActionLink.apply_instantiation_links(self, cfg, order=order)
Expand Down
5 changes: 3 additions & 2 deletions jsonargparse/_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import Any, Callable, List, Optional, Set, Tuple, Type, Union

from ._actions import _ActionConfigLoad
from ._common import is_dataclass_like, is_subclass
from ._common import get_class_instantiator, is_dataclass_like, is_subclass
from ._optionals import get_doc_short_description, pydantic_support
from ._parameter_resolvers import (
ParamData,
Expand Down Expand Up @@ -558,7 +558,8 @@ def group_instantiate_class(group, cfg):
value = {}
parent = cfg
key = group.dest
parent[key] = group.group_class(**value)
instantiator_fn = get_class_instantiator()
parent[key] = instantiator_fn(group.group_class, **value)


def strip_title(value):
Expand Down
9 changes: 6 additions & 3 deletions jsonargparse/_typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
_is_action_value_list,
remove_actions,
)
from ._common import is_dataclass_like, is_subclass, parent_parser, parser_context
from ._common import get_class_instantiator, is_dataclass_like, is_subclass, parent_parser, parser_context
from ._loaders_dumpers import (
get_loader_exceptions,
load_value,
Expand Down Expand Up @@ -1044,13 +1044,16 @@ def adapt_class_type(value, serialize, instantiate_classes, sub_add_kwargs, prev
if init_args:
value["init_args"] = init_args
return value

instantiator_fn = get_class_instantiator()

if skip_args:

def partial_instance(*args):
return val_class(*args, **{**init_args, **dict_kwargs})
return instantiator_fn(val_class, *args, **{**init_args, **dict_kwargs})

return partial_instance
return val_class(**{**init_args, **dict_kwargs})
return instantiator_fn(val_class, **{**init_args, **dict_kwargs})

prev_init_args = prev_val.get("init_args") if isinstance(prev_val, Namespace) else None

Expand Down
6 changes: 1 addition & 5 deletions jsonargparse/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,11 @@
Optional,
Tuple,
Type,
TypeVar,
Union,
get_type_hints,
)

from ._common import is_subclass, parser_capture, parser_context
from ._common import ClassType, is_subclass, parser_capture, parser_context
from ._deprecated import PathDeprecations
from ._loaders_dumpers import json_dump, load_value
from ._optionals import (
Expand Down Expand Up @@ -360,9 +359,6 @@ class ClassFromFunctionBase:
wrapped_function: Callable


ClassType = TypeVar("ClassType")


def class_from_function(
func: Callable[..., ClassType],
func_return: Optional[Type[ClassType]] = None,
Expand Down
14 changes: 14 additions & 0 deletions jsonargparse_tests/test_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,20 @@ def test_add_class_docstring_parse_fail(parser, logger):
assert "a1 description" not in help_str


def test_add_class_custom_instantiator(parser):
def instantiate(cls, **kwargs):
instance = cls(**kwargs)
instance.call = "custom"
return instance

parser.add_class_arguments(Class0, "a")
parser.add_instantiator(instantiate, Class0)
cfg = parser.parse_args([])
init = parser.instantiate_classes(cfg)
assert isinstance(init.a, Class0)
assert init.a.call == "custom"


# add_method_arguments tests


Expand Down
68 changes: 68 additions & 0 deletions jsonargparse_tests/test_subclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,74 @@ def test_subclass_in_subcommand_with_global_default_config_file(parser, subparse
assert cfg.fit.model.foo == 123


# custom instantiation tests


class CustomInstantiationBase:
pass


class CustomInstantiationSub(CustomInstantiationBase):
pass


def instantiator(value):
def instantiate(cls, **kwargs):
instance = cls(**kwargs)
instance.call = value
return instance

return instantiate


def test_custom_instantiation_argument_type(parser):
parser.add_argument("--cls", type=CustomInstantiationBase)
parser.add_instantiator(instantiator("argument type"), CustomInstantiationBase)
cfg = parser.parse_args(["--cls=CustomInstantiationBase"])
init = parser.instantiate_classes(cfg)
assert isinstance(init.cls, CustomInstantiationBase)
assert init.cls.call == "argument type"


def test_custom_instantiation_unused_for_subclass(parser):
parser.add_argument("--cls", type=CustomInstantiationBase)
parser.add_instantiator(instantiator("base"), CustomInstantiationBase, subclasses=False)
cfg = parser.parse_args(["--cls=CustomInstantiationSub"])
init = parser.instantiate_classes(cfg)
assert isinstance(init.cls, CustomInstantiationSub)
assert not hasattr(init.cls, "call")


def test_custom_instantiation_used_for_subclass(parser):
parser.add_argument("--cls", type=CustomInstantiationBase)
parser.add_instantiator(instantiator("subclass"), CustomInstantiationBase, subclasses=True)
cfg = parser.parse_args(["--cls=CustomInstantiationSub"])
init = parser.instantiate_classes(cfg)
assert isinstance(init.cls, CustomInstantiationSub)
assert init.cls.call == "subclass"


def test_custom_instantiation_prepend(parser):
parser.add_argument("--cls", type=CustomInstantiationBase)
parser.add_instantiator(instantiator("first"), CustomInstantiationSub)
parser.add_instantiator(instantiator("prepended"), CustomInstantiationBase, subclasses=True, prepend=True)
assert len(parser._instantiators) == 2
cfg = parser.parse_args(["--cls=CustomInstantiationSub"])
init = parser.instantiate_classes(cfg)
assert isinstance(init.cls, CustomInstantiationSub)
assert init.cls.call == "prepended"


def test_custom_instantiation_replace(parser):
first_instantiator = instantiator("first")
second_instantiator = instantiator("second")
parser.add_argument("--cls", type=CustomInstantiationBase)
parser.add_instantiator(first_instantiator, CustomInstantiationBase)
parser.add_instantiator(second_instantiator, CustomInstantiationBase)
assert len(parser._instantiators) == 1
assert list(parser._instantiators.values())[0] is second_instantiator


# environment tests


Expand Down
21 changes: 21 additions & 0 deletions jsonargparse_tests/test_subcommands.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
strip_meta,
)
from jsonargparse_tests.conftest import get_parse_args_stdout, get_parser_help
from jsonargparse_tests.test_subclasses import CustomInstantiationBase, instantiator


@pytest.fixture
Expand Down Expand Up @@ -266,3 +267,23 @@ def test_subsubcommands_wrong_add_order(parser):
with pytest.raises(ValueError) as ctx:
subcommands1.add_subcommand("a", parser_s1_a)
ctx.match("Multiple levels of subcommands must be added in level order")


def test_subcommands_custom_instantiator(parser, subparser, subtests):
subparser.add_argument("--cls", type=CustomInstantiationBase)
subcommands = parser.add_subcommands()
subcommands.add_subcommand("cmd", subparser)

with subtests.test("main parser"):
parser.add_instantiator(instantiator("main parser"), CustomInstantiationBase)
cfg = parser.parse_args(["cmd", "--cls", "CustomInstantiationBase"])
init = parser.instantiate_classes(cfg)
assert isinstance(init.cmd.cls, CustomInstantiationBase)
assert init.cmd.cls.call == "main parser"

with subtests.test("subparser"):
subparser.add_instantiator(instantiator("subparser"), CustomInstantiationBase)
cfg = parser.parse_args(["cmd", "--cls", "CustomInstantiationBase"])
init = parser.instantiate_classes(cfg)
assert isinstance(init.cmd.cls, CustomInstantiationBase)
assert init.cmd.cls.call == "subparser"

0 comments on commit fdbfe3b

Please sign in to comment.