Skip to content

Commit

Permalink
Allow assignment of MISSING in List[Structured] (omry#828)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jasha10 authored Nov 30, 2021
1 parent f362b5c commit 45ff30b
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 13 deletions.
30 changes: 17 additions & 13 deletions omegaconf/listconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,19 +98,23 @@ def _validate_set(self, key: Any, value: Any) -> None:
"$FULL_KEY is not optional and cannot be assigned None"
)

is_optional, target_type = _resolve_optional(self._metadata.element_type)
value_type = OmegaConf.get_type(value)

if (value_type is None and not is_optional) or (
is_structured_config(target_type)
and value_type is not None
and not issubclass(value_type, target_type)
):
msg = (
f"Invalid type assigned: {type_str(value_type)} is not a "
f"subclass of {type_str(target_type)}. value: {value}"
)
raise ValidationError(msg)
vk = get_value_kind(value)
if vk == ValueKind.MANDATORY_MISSING:
return
else:
is_optional, target_type = _resolve_optional(self._metadata.element_type)
value_type = OmegaConf.get_type(value)

if (value_type is None and not is_optional) or (
is_structured_config(target_type)
and value_type is not None
and not issubclass(value_type, target_type)
):
msg = (
f"Invalid type assigned: {type_str(value_type)} is not a "
f"subclass of {type_str(target_type)}. value: {value}"
)
raise ValidationError(msg)

def __deepcopy__(self, memo: Dict[int, Any]) -> "ListConfig":
res = ListConfig(None)
Expand Down
10 changes: 10 additions & 0 deletions tests/structured_conf/data/attr_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,11 +406,21 @@ class DictOfObjects:
users: Dict[str, User] = {"joe": User(name="Joe", age=18)}


@attr.s(auto_attribs=True)
class DictOfObjectsMissing:
users: Dict[str, User] = {"moe": MISSING}


@attr.s(auto_attribs=True)
class ListOfObjects:
users: List[User] = [User(name="Joe", age=18)]


@attr.s(auto_attribs=True)
class ListOfObjectsMissing:
users: List[User] = [MISSING]


class DictSubclass:
@attr.s(auto_attribs=True)
class Str2Str(Dict[str, str]):
Expand Down
10 changes: 10 additions & 0 deletions tests/structured_conf/data/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,11 +427,21 @@ class DictOfObjects:
)


@dataclass
class DictOfObjectsMissing:
users: Dict[str, User] = field(default_factory=lambda: {"moe": MISSING})


@dataclass
class ListOfObjects:
users: List[User] = field(default_factory=lambda: [User(name="Joe", age=18)])


@dataclass
class ListOfObjectsMissing:
users: List[User] = field(default_factory=lambda: [MISSING])


class DictSubclass:
@dataclass
class Str2Str(Dict[str, str]):
Expand Down
27 changes: 27 additions & 0 deletions tests/structured_conf/test_structured_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,20 @@ def test_dict_of_objects(self, module: Any) -> None:
with raises(ValidationError):
dct.fail = "fail"

def test_dict_of_objects_missing(self, module: Any) -> None:
conf = OmegaConf.structured(module.DictOfObjectsMissing)
dct = conf.users

assert OmegaConf.is_missing(dct, "moe")

dct.miss = MISSING
assert OmegaConf.is_missing(dct, "miss")

def test_assign_dict_of_objects(self, module: Any) -> None:
conf = OmegaConf.structured(module.DictOfObjects)
conf.users = {"poe": module.User(name="Poe", age=8), "miss": MISSING}
assert conf.users == {"poe": {"name": "Poe", "age": 8}, "miss": "???"}

def test_list_of_objects(self, module: Any) -> None:
conf = OmegaConf.structured(module.ListOfObjects)
assert conf.users[0].age == 18
Expand All @@ -758,6 +772,19 @@ def test_list_of_objects(self, module: Any) -> None:
with raises(ValidationError):
conf.users.append("fail")

def test_list_of_objects_missing(self, module: Any) -> None:
conf = OmegaConf.structured(module.ListOfObjectsMissing)

assert OmegaConf.is_missing(conf.users, 0)

conf.users.append(MISSING)
assert OmegaConf.is_missing(conf.users, 1)

def test_assign_list_of_objects(self, module: Any) -> None:
conf = OmegaConf.structured(module.ListOfObjects)
conf.users = [module.User(name="Poe", age=8), MISSING]
assert conf.users == [{"name": "Poe", "age": 8}, "???"]

def test_promote_api(self, module: Any) -> None:
conf = OmegaConf.create(module.AnyTypeConfig)
conf._promote(None)
Expand Down

0 comments on commit 45ff30b

Please sign in to comment.