diff --git a/omegaconf/_utils.py b/omegaconf/_utils.py index c9e23d520..7b5f3d6f6 100644 --- a/omegaconf/_utils.py +++ b/omegaconf/_utils.py @@ -341,11 +341,8 @@ def get_dataclass_data( is_optional, type_ = _resolve_optional(resolved_hints[field.name]) type_ = _resolve_forward(type_, obj.__module__) - if hasattr(obj, name): - value = getattr(obj, name) - if value == dataclasses.MISSING: - value = MISSING - else: + value = getattr(obj, name, MISSING) + if value in (MISSING, dataclasses.MISSING): if field.default_factory == dataclasses.MISSING: # type: ignore value = MISSING else: diff --git a/tests/structured_conf/data/attr_classes.py b/tests/structured_conf/data/attr_classes.py index dbd3ab933..ccb8b0811 100644 --- a/tests/structured_conf/data/attr_classes.py +++ b/tests/structured_conf/data/attr_classes.py @@ -556,3 +556,29 @@ class UntypedList: class UntypedDict: dict: Dict = {"foo": "var"} # type: ignore opt_dict: Optional[Dict] = None # type: ignore + + +class StructuredSubclass: + @attr.s(auto_attribs=True) + class ParentInts: + int1: int + int2: int + int3: int = attr.NOTHING # type: ignore + int4: int = MISSING + + @attr.s(auto_attribs=True) + class ChildInts(ParentInts): + int2: int = 5 + int3: int = 10 + int4: int = 15 + + @attr.s(auto_attribs=True) + class ParentContainers: + list1: List[int] = MISSING + list2: List[int] = [5, 6] + dict: Dict[str, Any] = MISSING + + @attr.s(auto_attribs=True) + class ChildContainers(ParentContainers): + list1: List[int] = [1, 2, 3] + dict: Dict[str, Any] = {"a": 5, "b": 6} diff --git a/tests/structured_conf/data/dataclasses.py b/tests/structured_conf/data/dataclasses.py index a5121fa34..146dd4814 100644 --- a/tests/structured_conf/data/dataclasses.py +++ b/tests/structured_conf/data/dataclasses.py @@ -577,3 +577,29 @@ class UntypedList: class UntypedDict: dict: Dict = field(default_factory=lambda: {"foo": "var"}) # type: ignore opt_dict: Optional[Dict] = None # type: ignore + + +class StructuredSubclass: + @dataclass + class ParentInts: + int1: int + int2: int + int3: int = dataclasses.MISSING # type: ignore + int4: int = MISSING + + @dataclass + class ChildInts(ParentInts): + int2: int = 5 + int3: int = 10 + int4: int = 15 + + @dataclass + class ParentContainers: + list1: List[int] = MISSING + list2: List[int] = field(default_factory=lambda: [5, 6]) + dict: Dict[str, Any] = MISSING + + @dataclass + class ChildContainers(ParentContainers): + list1: List[int] = field(default_factory=lambda: [1, 2, 3]) + dict: Dict[str, Any] = field(default_factory=lambda: {"a": 5, "b": 6}) diff --git a/tests/structured_conf/test_structured_config.py b/tests/structured_conf/test_structured_config.py index 2a7999dfd..2cabdff3f 100644 --- a/tests/structured_conf/test_structured_config.py +++ b/tests/structured_conf/test_structured_config.py @@ -1170,3 +1170,34 @@ def test_merge_missing_list_promotes_target_type(self, module: Any) -> None: c3 = OmegaConf.merge(c1, c2) with raises(ValidationError): c3.missing.append("xx") + + +class TestStructredConfigInheritance: + def test_leaf_node_inheritance(self, module: Any) -> None: + parent = OmegaConf.structured(module.StructuredSubclass.ParentInts) + child = OmegaConf.structured(module.StructuredSubclass.ChildInts) + + assert OmegaConf.is_missing(parent, "int1") + assert OmegaConf.is_missing(child, "int1") + + assert OmegaConf.is_missing(parent, "int2") + assert child.int2 == 5 + + assert OmegaConf.is_missing(parent, "int3") + assert child.int3 == 10 + + assert OmegaConf.is_missing(parent, "int4") + assert child.int4 == 15 + + def test_container_inheritance(self, module: Any) -> None: + parent = OmegaConf.structured(module.StructuredSubclass.ParentContainers) + child = OmegaConf.structured(module.StructuredSubclass.ChildContainers) + + assert OmegaConf.is_missing(parent, "list1") + assert child.list1 == [1, 2, 3] + + assert parent.list2 == [5, 6] + assert child.list2 == [5, 6] + + assert OmegaConf.is_missing(parent, "dict") + assert child.dict == {"a": 5, "b": 6}