Skip to content

Commit

Permalink
Fix dataclass nested in a list not setting defaults (#357).
Browse files Browse the repository at this point in the history
  • Loading branch information
mauvilsa committed Aug 31, 2023
1 parent 323b4a5 commit cd165e9
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 2 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ Fixed
^^^^^
- Remove private ``linked_targets`` parameter from API Reference (`#317
<https://github.com/omni-us/jsonargparse/issues/317>`__).
- Dataclass nested in list not setting defaults (`#357
<https://github.com/omni-us/jsonargparse/issues/357>`__)


v4.24.0 (2023-08-23)
Expand Down
5 changes: 3 additions & 2 deletions jsonargparse/_typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,7 @@ def adapt_typehints(
instantiate_classes=False,
prev_val=None,
append=False,
list_item=False,
enable_path=False,
sub_add_kwargs=None,
):
Expand Down Expand Up @@ -717,7 +718,7 @@ def adapt_typehints(
for n, v in enumerate(val):
if isinstance(prev_val, list) and len(prev_val) == len(val):
adapt_kwargs_n = {**adapt_kwargs, "prev_val": prev_val[n]}
val[n] = adapt_typehints(v, subtypehints[0], **adapt_kwargs_n)
val[n] = adapt_typehints(v, subtypehints[0], list_item=True, **adapt_kwargs_n)

# Dict, Mapping
elif typehint_origin in mapping_origin_types:
Expand Down Expand Up @@ -802,7 +803,7 @@ def adapt_typehints(
if serialize:
val = load_value(parser.dump(val, **dump_kwargs.get()))
elif isinstance(val, (dict, Namespace)):
val = parser.parse_object(val, defaults=sub_defaults.get())
val = parser.parse_object(val, defaults=sub_defaults.get() or list_item)
elif isinstance(val, NestedArg):
val = parser.parse_args([f"--{val.key}={val.val}"])
else:
Expand Down
17 changes: 17 additions & 0 deletions jsonargparse_tests/test_dataclass_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,23 @@ def test_add_dataclass_arguments(parser, subtests):
parser.add_dataclass_arguments(MixedClass, "c")


@dataclasses.dataclass
class NestedDefaultsA:
x: list = dataclasses.field(default_factory=list)
v: int = 1


@dataclasses.dataclass
class NestedDefaultsB:
a: List[NestedDefaultsA]


def test_add_dataclass_nested_defaults(parser):
parser.add_dataclass_arguments(NestedDefaultsB, "data")
cfg = parser.parse_args(["--data.a=[{}]"])
assert cfg.data == Namespace(a=[Namespace(x=[], v=1)])


class ClassDataAttributes:
def __init__(
self,
Expand Down

0 comments on commit cd165e9

Please sign in to comment.