Skip to content

Commit

Permalink
Improve Optional[type] validation (#749)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jasha10 authored Nov 15, 2021
1 parent 87bd60a commit 55ce964
Show file tree
Hide file tree
Showing 15 changed files with 887 additions and 92 deletions.
1 change: 1 addition & 0 deletions news/460.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve support for Optional element types in structured config container type hints.
10 changes: 3 additions & 7 deletions omegaconf/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,13 +199,8 @@ def _is_optional(obj: Any, key: Optional[Union[int, str]] = None) -> bool:
if key is not None:
assert isinstance(obj, Container)
obj = obj._get_node(key)
if isinstance(obj, Node):
return obj._is_optional()
else:
# In case `obj` is not a Node, treat it as optional by default.
# This is used in `ListConfig.append` and `ListConfig.insert`
# where the appended/inserted value might or might not be a Node.
return True
assert isinstance(obj, Node)
return obj._is_optional()


def _resolve_forward(type_: Type[Any], module: str) -> Type[Any]:
Expand Down Expand Up @@ -646,6 +641,7 @@ def get_dict_key_value_types(ref_type: Any) -> Tuple[Any, Any]:


def valid_value_annotation_type(type_: Any) -> bool:
_, type_ = _resolve_optional(type_)
return type_ is Any or is_primitive_type(type_) or is_structured_config(type_)


Expand Down
7 changes: 6 additions & 1 deletion omegaconf/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
format_and_raise,
get_value_kind,
split_key,
type_str,
valid_value_annotation_type,
)
from .errors import (
ConfigKeyError,
Expand Down Expand Up @@ -73,7 +75,10 @@ def __post_init__(self) -> None:
self.ref_type = Any
assert self.key_type is Any or isinstance(self.key_type, type)
if self.element_type is not None:
assert self.element_type is Any or isinstance(self.element_type, type)
if not valid_value_annotation_type(self.element_type):
raise ValidationError(
f"Unsupported value type: '{type_str(self.element_type, include_module_name=True)}'"
)

if self.flags is None:
self.flags = {}
Expand Down
32 changes: 20 additions & 12 deletions omegaconf/basecontainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
_is_missing_literal,
_is_missing_value,
_is_none,
_is_union,
_resolve_optional,
get_ref_type,
get_structured_config_data,
Expand All @@ -32,6 +33,7 @@
ConfigTypeError,
InterpolationResolutionError,
MissingMandatoryValue,
OmegaConfBaseException,
ReadonlyConfigError,
ValidationError,
)
Expand Down Expand Up @@ -99,6 +101,12 @@ def __getstate__(self) -> Dict[str, Any]:
dict_copy["_metadata"].ref_type = List
else:
assert False
if sys.version_info < (3, 7): # pragma: no cover
element_type = self._metadata.element_type
if _is_union(element_type):
raise OmegaConfBaseException(
"Serializing structured configs with `Union` element type requires python >= 3.7"
)
return dict_copy

# Support pickle
Expand Down Expand Up @@ -342,13 +350,12 @@ def expand(node: Container) -> None:
dest[key] = target_node
dest_node = dest._get_node(key)

if (
dest_node is None
and is_structured_config(dest._metadata.element_type)
and not missing_src_value
):
is_optional, et = _resolve_optional(dest._metadata.element_type)
if dest_node is None and is_structured_config(et) and not missing_src_value:
# merging into a new node. Use element_type as a base
dest[key] = DictConfig(content=dest._metadata.element_type, parent=dest)
dest[key] = DictConfig(
et, parent=dest, ref_type=et, is_optional=is_optional
)
dest_node = dest._get_node(key)

if dest_node is not None:
Expand Down Expand Up @@ -420,9 +427,9 @@ def _list_merge(dest: Any, src: Any) -> None:
temp_target.__dict__["_metadata"] = copy.deepcopy(
dest.__dict__["_metadata"]
)
et = dest._metadata.element_type
is_optional, et = _resolve_optional(dest._metadata.element_type)
if is_structured_config(et):
prototype = OmegaConf.structured(et)
prototype = DictConfig(et, ref_type=et, is_optional=is_optional)
for item in src._iter_ex(resolve=False):
if isinstance(item, DictConfig):
item = OmegaConf.merge(prototype, item)
Expand Down Expand Up @@ -541,14 +548,14 @@ def _set_item_impl(self, key: Any, value: Any) -> None:
)

def wrap(key: Any, val: Any) -> Node:
is_optional = True
if not is_structured_config(val):
ref_type = self._metadata.element_type
is_optional, ref_type = _resolve_optional(self._metadata.element_type)
else:
target = self._get_node(key)
if target is None:
if is_structured_config(val):
ref_type = self._metadata.element_type
is_optional, ref_type = _resolve_optional(
self._metadata.element_type
)
else:
assert isinstance(target, Node)
is_optional = target._is_optional()
Expand All @@ -566,6 +573,7 @@ def assign(value_key: Any, val: ValueNode) -> None:
v = val
v._set_parent(self)
v._set_key(value_key)
_update_types(v, self._metadata.element_type, None)
self.__dict__["_content"][value_key] = v

if input_node and target_node:
Expand Down
34 changes: 16 additions & 18 deletions omegaconf/dictconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
_is_missing_literal,
_is_missing_value,
_is_none,
_resolve_optional,
_valid_dict_key_annotation_type,
format_and_raise,
get_structured_config_data,
Expand All @@ -35,7 +36,6 @@
is_structured_config,
is_structured_config_frozen,
type_str,
valid_value_annotation_type,
)
from .base import Container, ContainerMetadata, DictKeyType, Node
from .basecontainer import BaseContainer
Expand Down Expand Up @@ -85,10 +85,6 @@ def __init__(
flags=flags,
),
)
if not valid_value_annotation_type(
element_type
) and not is_structured_config(element_type):
raise ValidationError(f"Unsupported value type: {element_type}")

if not _valid_dict_key_annotation_type(key_type):
raise KeyValidationError(f"Unsupported key type {key_type}")
Expand Down Expand Up @@ -240,25 +236,27 @@ def _validate_merge(self, value: Any) -> None:
)
raise ValidationError(msg)

def _validate_non_optional(self, key: Any, value: Any) -> None:
def _validate_non_optional(self, key: Optional[DictKeyType], value: Any) -> None:
if _is_none(value, resolve=True, throw_on_resolution_failure=False):

if key is not None:
child = self._get_node(key)
if child is not None:
assert isinstance(child, Node)
if not child._is_optional():
self._format_and_raise(
key=key,
value=value,
cause=ValidationError("child '$FULL_KEY' is not Optional"),
)
else:
if not self._is_optional():
self._format_and_raise(
key=None,
value=value,
cause=ValidationError("field '$FULL_KEY' is not Optional"),
field_is_optional = child._is_optional()
else:
field_is_optional, _ = _resolve_optional(
self._metadata.element_type
)
else:
field_is_optional = self._is_optional()

if not field_is_optional:
self._format_and_raise(
key=key,
value=value,
cause=ValidationError("field '$FULL_KEY' is not Optional"),
)

def _raise_invalid_value(
self, value: Any, value_type: Any, target_type: Any
Expand Down
36 changes: 16 additions & 20 deletions omegaconf/listconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,13 @@
ValueKind,
_is_missing_literal,
_is_none,
_is_optional,
_resolve_optional,
format_and_raise,
get_value_kind,
is_int,
is_primitive_list,
is_structured_config,
type_str,
valid_value_annotation_type,
)
from .base import Container, ContainerMetadata, Node
from .basecontainer import BaseContainer
Expand Down Expand Up @@ -70,10 +69,6 @@ def __init__(
flags=flags,
),
)
if not (valid_value_annotation_type(self._metadata.element_type)):
raise ValidationError(
f"Unsupported value type: {self._metadata.element_type}"
)

self.__dict__["_content"] = None
self._set_value(value=content, flags=flags)
Expand Down Expand Up @@ -103,19 +98,19 @@ def _validate_set(self, key: Any, value: Any) -> None:
"$FULL_KEY is not optional and cannot be assigned None"
)

target_type = self._metadata.element_type
is_optional, target_type = _resolve_optional(self._metadata.element_type)
value_type = OmegaConf.get_type(value)
if is_structured_config(target_type):
if (
target_type is not None
and value_type is not None
and not issubclass(value_type, target_type)
):
msg = (
f"Invalid type assigned: {type_str(value_type)} is not a "
f"subclass of {type_str(target_type)}. value: {value}"
)
raise ValidationError(msg)

if (value_type is None and not is_optional) or (
is_structured_config(target_type)
and value_type is not None
and not issubclass(value_type, target_type)
):
msg = (
f"Invalid type assigned: {type_str(value_type)} is not a "
f"subclass of {type_str(target_type)}. value: {value}"
)
raise ValidationError(msg)

def __deepcopy__(self, memo: Dict[int, Any]) -> "ListConfig":
res = ListConfig(None)
Expand Down Expand Up @@ -281,11 +276,12 @@ def insert(self, index: int, item: Any) -> None:
assert isinstance(self.__dict__["_content"], list)
# insert place holder
self.__dict__["_content"].insert(index, None)
is_optional, ref_type = _resolve_optional(self._metadata.element_type)
node = _maybe_wrap(
ref_type=self.__dict__["_metadata"].element_type,
ref_type=ref_type,
key=index,
value=item,
is_optional=_is_optional(item),
is_optional=is_optional,
parent=self,
)
self._validate_set(key=index, value=node)
Expand Down
19 changes: 19 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ class Users:
name2user: Dict[str, User] = field(default_factory=dict)


@dataclass
class OptionalUsers:
name2user: Dict[str, Optional[User]] = field(default_factory=dict)


@dataclass
class ConfWithMissingDict:
dict: Dict[str, Any] = MISSING
Expand Down Expand Up @@ -215,6 +220,12 @@ class SubscriptedList:
list: List[int] = field(default_factory=lambda: [1, 2])


@dataclass
class SubscriptedListOpt:
opt_list: Optional[List[int]] = field(default_factory=lambda: [1, 2])
list_opt: List[Optional[int]] = field(default_factory=lambda: [1, 2, None])


@dataclass
class UntypedDict:
dict: Dict = field(default_factory=lambda: {"foo": "var"}) # type: ignore
Expand All @@ -230,6 +241,14 @@ class SubscriptedDict:
dict_bool: Dict[bool, int] = field(default_factory=lambda: {True: 4, False: 5})


@dataclass
class SubscriptedDictOpt:
opt_dict: Optional[Dict[str, int]] = field(default_factory=lambda: {"foo": 4})
dict_opt: Dict[str, Optional[int]] = field(
default_factory=lambda: {"foo": 4, "bar": None}
)


@dataclass
class InterpolationList:
list: List[float] = II("optimization.lr")
Expand Down
2 changes: 1 addition & 1 deletion tests/structured_conf/test_structured_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def test_merge_error_override_bad_type(self, module: Any) -> None:

def test_error_message(self, module: Any) -> None:
cfg = OmegaConf.structured(module.StructuredOptional)
msg = re.escape("child 'not_optional' is not Optional")
msg = re.escape("field 'not_optional' is not Optional")
with raises(ValidationError, match=msg):
cfg.not_optional = None

Expand Down
15 changes: 3 additions & 12 deletions tests/structured_conf/test_structured_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,10 +808,7 @@ def test_set_list_correct_type(self, module: Any) -> None:
3.1415,
["foo", True, 1.2],
User(),
param(
[None],
marks=mark.xfail, # https://github.com/omry/omegaconf/issues/579
),
param([None]),
],
)
def test_assign_wrong_type_to_list(self, module: Any, value: Any) -> None:
Expand All @@ -825,10 +822,7 @@ def test_assign_wrong_type_to_list(self, module: Any, value: Any) -> None:
@mark.parametrize(
"value",
[
param(
None,
marks=mark.xfail, # https://github.com/omry/omegaconf/issues/579
),
param(None),
True,
"str",
3.1415,
Expand All @@ -849,10 +843,7 @@ def test_insert_wrong_type_to_list(self, module: Any, value: Any) -> None:
3.1415,
["foo", True, 1.2],
{"foo": True},
param(
{"foo": None},
marks=mark.xfail, # expected failure, https://github.com/omry/omegaconf/issues/579
),
param({"foo": None}),
User(age=1, name="foo"),
{"user": User(age=1, name="foo")},
ListConfig(content=[1, 2], ref_type=List[int], element_type=int),
Expand Down
Loading

0 comments on commit 55ce964

Please sign in to comment.