Skip to content

Commit

Permalink
Fix default_factory support in structured config subclasses (#826)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jasha10 authored Nov 23, 2021
1 parent 55ce964 commit f362b5c
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 5 deletions.
7 changes: 2 additions & 5 deletions omegaconf/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,11 +341,8 @@ def get_dataclass_data(
is_optional, type_ = _resolve_optional(resolved_hints[field.name])
type_ = _resolve_forward(type_, obj.__module__)

if hasattr(obj, name):
value = getattr(obj, name)
if value == dataclasses.MISSING:
value = MISSING
else:
value = getattr(obj, name, MISSING)
if value in (MISSING, dataclasses.MISSING):
if field.default_factory == dataclasses.MISSING: # type: ignore
value = MISSING
else:
Expand Down
26 changes: 26 additions & 0 deletions tests/structured_conf/data/attr_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,3 +556,29 @@ class UntypedList:
class UntypedDict:
dict: Dict = {"foo": "var"} # type: ignore
opt_dict: Optional[Dict] = None # type: ignore


class StructuredSubclass:
@attr.s(auto_attribs=True)
class ParentInts:
int1: int
int2: int
int3: int = attr.NOTHING # type: ignore
int4: int = MISSING

@attr.s(auto_attribs=True)
class ChildInts(ParentInts):
int2: int = 5
int3: int = 10
int4: int = 15

@attr.s(auto_attribs=True)
class ParentContainers:
list1: List[int] = MISSING
list2: List[int] = [5, 6]
dict: Dict[str, Any] = MISSING

@attr.s(auto_attribs=True)
class ChildContainers(ParentContainers):
list1: List[int] = [1, 2, 3]
dict: Dict[str, Any] = {"a": 5, "b": 6}
26 changes: 26 additions & 0 deletions tests/structured_conf/data/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,3 +577,29 @@ class UntypedList:
class UntypedDict:
dict: Dict = field(default_factory=lambda: {"foo": "var"}) # type: ignore
opt_dict: Optional[Dict] = None # type: ignore


class StructuredSubclass:
@dataclass
class ParentInts:
int1: int
int2: int
int3: int = dataclasses.MISSING # type: ignore
int4: int = MISSING

@dataclass
class ChildInts(ParentInts):
int2: int = 5
int3: int = 10
int4: int = 15

@dataclass
class ParentContainers:
list1: List[int] = MISSING
list2: List[int] = field(default_factory=lambda: [5, 6])
dict: Dict[str, Any] = MISSING

@dataclass
class ChildContainers(ParentContainers):
list1: List[int] = field(default_factory=lambda: [1, 2, 3])
dict: Dict[str, Any] = field(default_factory=lambda: {"a": 5, "b": 6})
31 changes: 31 additions & 0 deletions tests/structured_conf/test_structured_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1170,3 +1170,34 @@ def test_merge_missing_list_promotes_target_type(self, module: Any) -> None:
c3 = OmegaConf.merge(c1, c2)
with raises(ValidationError):
c3.missing.append("xx")


class TestStructredConfigInheritance:
def test_leaf_node_inheritance(self, module: Any) -> None:
parent = OmegaConf.structured(module.StructuredSubclass.ParentInts)
child = OmegaConf.structured(module.StructuredSubclass.ChildInts)

assert OmegaConf.is_missing(parent, "int1")
assert OmegaConf.is_missing(child, "int1")

assert OmegaConf.is_missing(parent, "int2")
assert child.int2 == 5

assert OmegaConf.is_missing(parent, "int3")
assert child.int3 == 10

assert OmegaConf.is_missing(parent, "int4")
assert child.int4 == 15

def test_container_inheritance(self, module: Any) -> None:
parent = OmegaConf.structured(module.StructuredSubclass.ParentContainers)
child = OmegaConf.structured(module.StructuredSubclass.ChildContainers)

assert OmegaConf.is_missing(parent, "list1")
assert child.list1 == [1, 2, 3]

assert parent.list2 == [5, 6]
assert child.list2 == [5, 6]

assert OmegaConf.is_missing(parent, "dict")
assert child.dict == {"a": 5, "b": 6}

0 comments on commit f362b5c

Please sign in to comment.