Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

get_dataclass_data: branch on dataclass vs dataclass instance #832

Merged
merged 1 commit into from
Apr 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions news/831.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix bugs related to creation of structured configs from dataclasses having fields with a default_factory
16 changes: 11 additions & 5 deletions omegaconf/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ def get_dataclass_data(

flags = {"allow_objects": allow_objects} if allow_objects is not None else {}
d = {}
is_type = isinstance(obj, type)
obj_type = get_type_of(obj)
dummy_parent = OmegaConf.create({}, flags=flags)
dummy_parent._metadata.object_type = obj_type
Expand All @@ -344,13 +345,18 @@ def get_dataclass_data(
name = field.name
is_optional, type_ = _resolve_optional(resolved_hints[field.name])
type_ = _resolve_forward(type_, obj.__module__)
has_default = field.default != dataclasses.MISSING
has_default_factory = field.default_factory != dataclasses.MISSING # type: ignore

value = getattr(obj, name, MISSING)
if value in (MISSING, dataclasses.MISSING):
if field.default_factory == dataclasses.MISSING: # type: ignore
value = MISSING
else:
if not is_type:
value = getattr(obj, name)
else:
if has_default:
value = field.default
elif has_default_factory:
value = field.default_factory() # type: ignore
else:
value = MISSING

if _is_union(type_):
e = ConfigValueError(
Expand Down
10 changes: 10 additions & 0 deletions tests/structured_conf/data/attr_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,16 @@ class ChildContainers(ParentContainers):
list1: List[int] = [1, 2, 3]
dict: Dict[str, Any] = {"a": 5, "b": 6}

@attr.s(auto_attribs=True)
class ParentNoDefaultFactory:
no_default_to_list: Any
int_to_list: Any = 1

@attr.s(auto_attribs=True)
class ChildWithDefaultFactory(ParentNoDefaultFactory):
no_default_to_list: Any = ["hi"]
int_to_list: Any = ["hi"]


@attr.s(auto_attribs=True)
class HasInitFalseFields:
Expand Down
10 changes: 10 additions & 0 deletions tests/structured_conf/data/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,16 @@ class ChildContainers(ParentContainers):
list1: List[int] = field(default_factory=lambda: [1, 2, 3])
dict: Dict[str, Any] = field(default_factory=lambda: {"a": 5, "b": 6})

@dataclass
class ParentNoDefaultFactory:
no_default_to_list: Any
int_to_list: Any = 1

@dataclass
class ChildWithDefaultFactory(ParentNoDefaultFactory):
no_default_to_list: Any = field(default_factory=lambda: ["hi"])
int_to_list: Any = field(default_factory=lambda: ["hi"])


@dataclass
class HasInitFalseFields:
Expand Down
36 changes: 35 additions & 1 deletion tests/structured_conf/test_structured_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import inspect
import sys
from importlib import import_module
from typing import Any, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional

from pytest import fixture, mark, param, raises

Expand Down Expand Up @@ -211,6 +211,13 @@ def validate(cfg: DictConfig) -> None:
conf2 = OmegaConf.structured(module.ConfigWithList())
validate(conf2)

def test_config_with_list_nondefault_values(self, module: Any) -> None:
conf1 = OmegaConf.structured(module.ConfigWithList(list1=[4, 5, 6]))
assert conf1.list1 == [4, 5, 6]

conf2 = OmegaConf.structured(module.ConfigWithList(list1=MISSING))
assert OmegaConf.is_missing(conf2, "list1")

def test_assignment_to_nested_structured_config(self, module: Any) -> None:
conf = OmegaConf.structured(module.NestedConfig)
with raises(ValidationError):
Expand All @@ -236,6 +243,13 @@ def validate(cfg: DictConfig) -> None:
conf2 = OmegaConf.structured(module.ConfigWithDict())
validate(conf2)

def test_config_with_dict_nondefault_values(self, module: Any) -> None:
conf1 = OmegaConf.structured(module.ConfigWithDict(dict1={"baz": "qux"}))
assert conf1.dict1 == {"baz": "qux"}

conf2 = OmegaConf.structured(module.ConfigWithDict(dict1=MISSING))
assert OmegaConf.is_missing(conf2, "dict1")

def test_structured_config_struct_behavior(self, module: Any) -> None:
def validate(cfg: DictConfig) -> None:
assert not OmegaConf.is_struct(cfg)
Expand Down Expand Up @@ -1231,6 +1245,26 @@ def test_container_inheritance(self, module: Any) -> None:
assert OmegaConf.is_missing(parent, "dict")
assert child.dict == {"a": 5, "b": 6}

@mark.parametrize(
"create_fn",
[
param(lambda cls: OmegaConf.structured(cls), id="create_from_class"),
param(lambda cls: OmegaConf.structured(cls()), id="create_from_instance"),
],
)
def test_subclass_using_default_factory(
self, module: Any, create_fn: Callable[[Any], DictConfig]
) -> None:
"""
When a structured config field has a default and a subclass defines a
default_factory for the same field, ensure that the DictConfig created
from the subclass uses the subclass' default_factory (not the parent
class' default).
"""
cfg = create_fn(module.StructuredSubclass.ChildWithDefaultFactory)
assert cfg.no_default_to_list == ["hi"]
assert cfg.int_to_list == ["hi"]


class TestNestedContainers:
@mark.parametrize(
Expand Down