Skip to content

Commit

Permalink
Fix AST resolver kwargs.pop() with conflicting defaults not setting t…
Browse files Browse the repository at this point in the history
…he conditional default (#362).
  • Loading branch information
mauvilsa committed Aug 31, 2023
1 parent cd165e9 commit 4148f4e
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 51 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ Fixed
<https://github.com/omni-us/jsonargparse/issues/317>`__).
- Dataclass nested in list not setting defaults (`#357
<https://github.com/omni-us/jsonargparse/issues/357>`__)
- AST resolver ``kwargs.pop()`` with conflicting defaults not setting the
conditional default (`#362
<https://github.com/omni-us/jsonargparse/issues/362>`__).


v4.24.0 (2023-08-23)
Expand Down
14 changes: 9 additions & 5 deletions DOCUMENTATION.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1693,11 +1693,11 @@ A special case which is supported but with caveats, is multiple calls that use
The resolved parameters that have the same type hint and default across all
calls are supported normally. When there is a discrepancy between the calls, the
parameters behave differently and are shown in the help in special "Conditional
arguments" sections. The main difference is that these arguments are not
included in :py:meth:`.ArgumentParser.get_defaults` or the output of
``--print_config``. This is necessary because the parser does not know which of
the calls will be used at runtime, and adding them would cause
parameters behave differently and are shown in the help with the default like
``Conditional<ast-resolver> {DEFAULT_1, ...}``. The main difference is that
these parameters are not included in :py:meth:`.ArgumentParser.get_defaults` or
the output of ``--print_config``. This is necessary because the parser does not
know which of the calls will be used at runtime, and adding them would cause
:py:meth:`.ArgumentParser.instantiate_classes` to fail due to unexpected keyword
arguments.

Expand Down Expand Up @@ -1747,6 +1747,10 @@ Without the stubs resolver, the
given to the ``a`` and ``b`` arguments, instead of ``float``. And this means
that the parser would not fail if given an invalid value, for instance a string.

It is not possible to know the defaults of parameters discovered only because of
the stubs. In these cases in the parser help the default is shown as
``Unknown<stubs-resolver>`` and not included in
:py:meth:`.ArgumentParser.get_defaults` or the output of ``--print_config``.

.. _sub-classes:

Expand Down
77 changes: 36 additions & 41 deletions jsonargparse/_parameter_resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ class ParamData:
parameter_attributes = [s[1:] for s in inspect.Parameter.__slots__] # type: ignore
kinds = inspect._ParameterKind
ast_assign_type: Tuple[Type[ast.AST], ...] = (ast.AnnAssign, ast.Assign)
param_kwargs_get = "**.get()"
param_kwargs_pop = "**.pop()"
param_kwargs_pop_or_get = "**.pop|get():"


class SourceNotAvailable(Exception):
Expand Down Expand Up @@ -304,29 +303,6 @@ def add_stub_types(stubs: Optional[Dict[str, Any]], params: ParamList, component
ast_literals = {ast.dump(ast.parse(v, mode="eval").body): partial(ast.literal_eval, v) for v in ["{}", "[]"]}


def get_kwargs_pop_or_get_parameter(node, component, parent, doc_params, log_debug):
name = ast_get_constant_value(node.args[0])
if ast_is_constant(node.args[1]):
default = ast_get_constant_value(node.args[1])
else:
default = ast.dump(node.args[1])
if default in ast_literals:
default = ast_literals[default]()
else:
default = None
log_debug(f"unsupported kwargs pop/get default: {ast_str(node)}")
return ParamData(
name=name,
annotation=inspect._empty,
default=default,
kind=kinds.KEYWORD_ONLY,
doc=doc_params.get(name),
parent=parent,
component=component,
origin=param_kwargs_get if node.func.attr == "get" else param_kwargs_pop,
)


def is_param_subclass_instance_default(param: ParamData) -> bool:
if is_dataclass_like(type(param.default)):
return False
Expand Down Expand Up @@ -368,31 +344,27 @@ def group_parameters(params_list: List[ParamList]) -> ParamList:
param.origin = None
return params_list[0]
grouped = []
params_count = 0
params_skip = set()
non_get_pop_count = 0
params_dict = defaultdict(lambda: [])
for params in params_list:
if params[0].origin not in {param_kwargs_get, param_kwargs_pop}:
params_count += 1
if not (params[0].origin or "").startswith(param_kwargs_pop_or_get): # type: ignore
non_get_pop_count += 1
for param in params:
if param.name not in params_skip and param.kind != kinds.POSITIONAL_ONLY:
if param.kind != kinds.POSITIONAL_ONLY:
params_dict[param.name].append(param)
if param.origin == param_kwargs_pop:
params_skip.add(param.name)
for params in params_dict.values():
gparam = params[0]
types = unique(p.annotation for p in params if p.annotation is not inspect._empty)
defaults = unique(p.default for p in params if p.default is not inspect._empty)
if len(params) >= params_count and len(types) <= 1 and len(defaults) <= 1:
if len(params) >= non_get_pop_count and len(types) <= 1 and len(defaults) <= 1:
gparam.origin = None
else:
gparam.parent = tuple(p.parent for p in params)
gparam.component = tuple(p.component for p in params)
gparam.origin = tuple(p.origin for p in params)
gparam.default = ConditionalDefault(
"ast-resolver",
(p.default for p in params) if len(defaults) > 1 else defaults,
)
if len(params) < non_get_pop_count:
defaults += ["NOT_ACCEPTED"]
gparam.default = ConditionalDefault("ast-resolver", defaults)
if len(types) > 1:
gparam.annotation = Union[tuple(types)] if types else inspect._empty
docs = [p.doc for p in params if p.doc]
Expand Down Expand Up @@ -644,6 +616,28 @@ def get_default_nodes(self, param_names: set):
default_nodes = [d for n, d in enumerate(default_nodes) if arg_nodes[n].arg in param_names]
return default_nodes

def get_kwargs_pop_or_get_parameter(self, node, component, parent, doc_params):
name = ast_get_constant_value(node.args[0])
if ast_is_constant(node.args[1]):
default = ast_get_constant_value(node.args[1])
else:
default = ast.dump(node.args[1])
if default in ast_literals:
default = ast_literals[default]()
else:
default = None
self.log_debug(f"unsupported kwargs pop/get default: {ast_str(node)}")
return ParamData(
name=name,
annotation=inspect._empty,
default=default,
kind=kinds.KEYWORD_ONLY,
doc=doc_params.get(name),
parent=parent,
component=component,
origin=param_kwargs_pop_or_get + self.get_node_origin(node),
)

def get_parameters_args_and_kwargs(self) -> Tuple[ParamList, ParamList]:
self.parse_source_tree()
args_name = getattr(self.component_node.args.vararg, "arg", None)
Expand All @@ -664,9 +658,7 @@ def get_parameters_args_and_kwargs(self) -> Tuple[ParamList, ParamList]:
for node in [v for k, v in values_found if k == kwargs_name]:
if isinstance(node, ast.Call):
if ast_is_kwargs_pop_or_get(node, kwargs_value_dump):
param = get_kwargs_pop_or_get_parameter(
node, self.component, self.parent, self.doc_params, self.log_debug
)
param = self.get_kwargs_pop_or_get_parameter(node, self.component, self.parent, self.doc_params)
params_list.append([param])
continue
kwarg = ast_get_call_kwarg_with_value(node, kwargs_value)
Expand Down Expand Up @@ -717,12 +709,15 @@ def get_parameters_attr_use_in_members(self, attr_name) -> ParamList:
self.log_debug(f"did not find use of {self.self_name}.{attr_name} in members of {self.parent}")
return []

def get_node_origin(self, node) -> str:
return f"{get_parameter_origins(self.component, self.parent)}:{node.lineno}"

def add_node_origins(self, params: ParamList, node) -> None:
origin = None
for param in params:
if param.origin is None:
if not origin:
origin = f"{get_parameter_origins(self.component, self.parent)}:{node.lineno}"
origin = self.get_node_origin(node)
param.origin = origin

def get_parameters_call_attr(self, attr_name: str, attr_value: ast.AST) -> Optional[ParamList]:
Expand Down
48 changes: 44 additions & 4 deletions jsonargparse_tests/test_parameter_resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
import xml.dom
from calendar import Calendar
from random import shuffle
from typing import Any, Callable, Dict, List
from typing import Any, Callable, Dict, List, Union
from unittest.mock import patch

import pytest

from jsonargparse import Namespace, class_from_function
from jsonargparse._optionals import docstring_parser_support
from jsonargparse._parameter_resolvers import ConditionalDefault, is_lambda
from jsonargparse._parameter_resolvers import get_signature_parameters as get_params
from jsonargparse._parameter_resolvers import is_lambda
from jsonargparse_tests.conftest import capture_logs, source_unavailable


Expand Down Expand Up @@ -383,6 +383,21 @@ def function_pop_get_from_kwargs(kn1: int = 0, **kw):
kw.pop("pk1", "")


def function_pop_get_conditional(p1: str, **kw):
"""
Args:
p1: help for p1
p2: help for p2
p3: help for p3
"""
kw.get("p3", "x")
if p1 == "a":
kw.pop("p2", None)
elif p1 == "b":
kw.pop("p2", 3)
kw.get("p3", "y")


def function_with_bug(**kws):
return does_not_exist(**kws) # noqa: F821

Expand Down Expand Up @@ -451,6 +466,7 @@ def assert_params(params, expected, origins={}):
assert expected == [p.name for p in params]
docs = [f"help for {p.name}" for p in params] if docstring_parser_support else [None] * len(params)
assert docs == [p.doc for p in params]
assert all(isinstance(params[n].default, ConditionalDefault) for n in origins.keys())
param_origins = {
n: [o.split(f"{__name__}.", 1)[1] for o in p.origin] for n, p in enumerate(params) if p.origin is not None
}
Expand Down Expand Up @@ -502,15 +518,20 @@ def test_get_params_class_with_kwargs_in_dict_attribute():


def test_get_params_class_kwargs_in_attr_method_conditioned_on_arg():
params = get_params(ClassG)
assert_params(
get_params(ClassG),
params,
["func", "kmg1", "kmg2", "kmg3", "kmg4"],
{
2: ["ClassG._run:3", "ClassG._run:5"],
3: ["ClassG._run:3", "ClassG._run:5"],
4: ["ClassG._run:5"],
},
)
assert params[2].annotation == Union[str, float]
assert str(params[2].default) == "Conditional<ast-resolver> {-, 2.3}"
assert str(params[3].default) == "Conditional<ast-resolver> {True, False}"
assert str(params[4].default) == "Conditional<ast-resolver> {4, NOT_ACCEPTED}"
with source_unavailable():
assert_params(get_params(ClassG), ["func"])

Expand Down Expand Up @@ -645,12 +666,31 @@ def test_get_params_function_call_classmethod():
def test_get_params_function_pop_get_from_kwargs(logger):
with capture_logs(logger) as logs:
params = get_params(function_pop_get_from_kwargs, logger=logger)
assert_params(params, ["kn1", "k2", "kn2", "kn3", "kn4", "pk1"])
assert str(params[1].default) == "Conditional<ast-resolver> {2, 1}"
assert_params(
params,
["kn1", "k2", "kn2", "kn3", "kn4", "pk1"],
{1: ["function_pop_get_from_kwargs:10", "function_pop_get_from_kwargs:15"]},
)
assert "unsupported kwargs pop/get default" in logs.getvalue()
with source_unavailable():
assert_params(get_params(function_pop_get_from_kwargs), ["kn1"])


def test_get_params_function_pop_get_conditional():
params = get_params(function_pop_get_conditional)
assert str(params[1].default) == "Conditional<ast-resolver> {x, y}"
assert str(params[2].default) == "Conditional<ast-resolver> {None, 3}"
assert_params(
params,
["p1", "p3", "p2"],
{
1: ["function_pop_get_conditional:8", "function_pop_get_conditional:13"],
2: ["function_pop_get_conditional:10", "function_pop_get_conditional:12"],
},
)


def test_get_params_function_module_class():
params = get_params(function_module_class)
assert ["firstweekday"] == [p.name for p in params]
Expand Down
2 changes: 1 addition & 1 deletion jsonargparse_tests/test_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ def test_add_class_conditional_kwargs(parser):
"help for kmg1 (type: int, default: 1)",
"help for kmg2 (type: Union[str, float], default: Conditional<ast-resolver> {-, 2.3})",
"help for kmg3 (type: bool, default: Conditional<ast-resolver> {True, False})",
"help for kmg4 (type: int, default: Conditional<ast-resolver> 4)",
"help for kmg4 (type: int, default: Conditional<ast-resolver> {4, NOT_ACCEPTED})",
]
for value in expected:
assert value in help_str
Expand Down

0 comments on commit 4148f4e

Please sign in to comment.