Skip to content

Commit

Permalink
Modify get_dataclass_data to fix omry#830 and omry#831
Browse files Browse the repository at this point in the history
  • Loading branch information
Jasha10 committed Apr 13, 2022
1 parent 5c71703 commit f787425
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 10 deletions.
1 change: 1 addition & 0 deletions news/831.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix bugs related to creation of structured configs from dataclasses having fields with a default_factory
16 changes: 11 additions & 5 deletions omegaconf/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ def get_dataclass_data(

flags = {"allow_objects": allow_objects} if allow_objects is not None else {}
d = {}
is_type = isinstance(obj, type)
obj_type = get_type_of(obj)
dummy_parent = OmegaConf.create({}, flags=flags)
dummy_parent._metadata.object_type = obj_type
Expand All @@ -344,13 +345,18 @@ def get_dataclass_data(
name = field.name
is_optional, type_ = _resolve_optional(resolved_hints[field.name])
type_ = _resolve_forward(type_, obj.__module__)
has_default = field.default != dataclasses.MISSING
has_default_factory = field.default_factory != dataclasses.MISSING # type: ignore

value = getattr(obj, name, MISSING)
if value in (MISSING, dataclasses.MISSING):
if field.default_factory == dataclasses.MISSING: # type: ignore
value = MISSING
else:
if not is_type:
value = getattr(obj, name)
else:
if has_default:
value = field.default
elif has_default_factory:
value = field.default_factory() # type: ignore
else:
value = MISSING

if _is_union(type_):
e = ConfigValueError(
Expand Down
10 changes: 10 additions & 0 deletions tests/structured_conf/data/attr_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,16 @@ class ChildContainers(ParentContainers):
list1: List[int] = [1, 2, 3]
dict: Dict[str, Any] = {"a": 5, "b": 6}

@attr.s(auto_attribs=True)
class ParentNoDefaultFactory:
no_default_to_list: Any
int_to_list: Any = 1

@attr.s(auto_attribs=True)
class ChildWithDefaultFactory(ParentNoDefaultFactory):
no_default_to_list: Any = ["hi"]
int_to_list: Any = ["hi"]


@attr.s(auto_attribs=True)
class HasInitFalseFields:
Expand Down
10 changes: 10 additions & 0 deletions tests/structured_conf/data/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,16 @@ class ChildContainers(ParentContainers):
list1: List[int] = field(default_factory=lambda: [1, 2, 3])
dict: Dict[str, Any] = field(default_factory=lambda: {"a": 5, "b": 6})

@dataclass
class ParentNoDefaultFactory:
no_default_to_list: Any
int_to_list: Any = 1

@dataclass
class ChildWithDefaultFactory(ParentNoDefaultFactory):
no_default_to_list: Any = field(default_factory=lambda: ["hi"])
int_to_list: Any = field(default_factory=lambda: ["hi"])


@dataclass
class HasInitFalseFields:
Expand Down
43 changes: 38 additions & 5 deletions tests/structured_conf/test_structured_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import inspect
import sys
from importlib import import_module
from typing import Any, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional

from pytest import fixture, mark, param, raises

Expand Down Expand Up @@ -208,8 +208,15 @@ def validate(cfg: DictConfig) -> None:
conf1 = OmegaConf.structured(module.ConfigWithList)
validate(conf1)

conf1 = OmegaConf.structured(module.ConfigWithList())
validate(conf1)
conf2 = OmegaConf.structured(module.ConfigWithList())
validate(conf2)

def test_config_with_list_nondefault_values(self, module: Any) -> None:
conf1 = OmegaConf.structured(module.ConfigWithList(list1=[4, 5, 6]))
assert conf1.list1 == [4, 5, 6]

conf2 = OmegaConf.structured(module.ConfigWithList(list1=MISSING))
assert OmegaConf.is_missing(conf2, "list1")

def test_assignment_to_nested_structured_config(self, module: Any) -> None:
conf = OmegaConf.structured(module.NestedConfig)
Expand All @@ -232,8 +239,16 @@ def validate(cfg: DictConfig) -> None:

conf1 = OmegaConf.structured(module.ConfigWithDict)
validate(conf1)
conf1 = OmegaConf.structured(module.ConfigWithDict())
validate(conf1)

conf2 = OmegaConf.structured(module.ConfigWithDict())
validate(conf2)

def test_config_with_dict_nondefault_values(self, module: Any) -> None:
conf1 = OmegaConf.structured(module.ConfigWithDict(dict1={"baz": "qux"}))
assert conf1.dict1 == {"baz": "qux"}

conf2 = OmegaConf.structured(module.ConfigWithDict(dict1=MISSING))
assert OmegaConf.is_missing(conf2, "dict1")

def test_structured_config_struct_behavior(self, module: Any) -> None:
def validate(cfg: DictConfig) -> None:
Expand Down Expand Up @@ -1230,6 +1245,24 @@ def test_container_inheritance(self, module: Any) -> None:
assert OmegaConf.is_missing(parent, "dict")
assert child.dict == {"a": 5, "b": 6}

@mark.parametrize(
"create_fn",
[
param(lambda cls: OmegaConf.structured(cls), id="create_from_class"),
param(lambda cls: OmegaConf.structured(cls()), id="create_from_instance"),
],
)
def test_subclass_using_default_factory(
self, module: Any, create_fn: Callable[[Any], DictConfig]
) -> None:
"""
When a structured config class field has a default and a subclass defines a default_factory for the same field,
ensure that the DictConfig object created from the subclass uses the default_factory.
"""
cfg = create_fn(module.StructuredSubclass.ChildWithDefaultFactory)
assert cfg.no_default_to_list == ["hi"]
assert cfg.int_to_list == ["hi"]


class TestNestedContainers:
@mark.parametrize(
Expand Down

0 comments on commit f787425

Please sign in to comment.