diff --git a/news/460.feature b/news/460.feature new file mode 100644 index 000000000..054982481 --- /dev/null +++ b/news/460.feature @@ -0,0 +1 @@ +Improve support for Optional element types in structured config container type hints. diff --git a/omegaconf/_utils.py b/omegaconf/_utils.py index d82e78ded..c9e23d520 100644 --- a/omegaconf/_utils.py +++ b/omegaconf/_utils.py @@ -199,13 +199,8 @@ def _is_optional(obj: Any, key: Optional[Union[int, str]] = None) -> bool: if key is not None: assert isinstance(obj, Container) obj = obj._get_node(key) - if isinstance(obj, Node): - return obj._is_optional() - else: - # In case `obj` is not a Node, treat it as optional by default. - # This is used in `ListConfig.append` and `ListConfig.insert` - # where the appended/inserted value might or might not be a Node. - return True + assert isinstance(obj, Node) + return obj._is_optional() def _resolve_forward(type_: Type[Any], module: str) -> Type[Any]: @@ -646,6 +641,7 @@ def get_dict_key_value_types(ref_type: Any) -> Tuple[Any, Any]: def valid_value_annotation_type(type_: Any) -> bool: + _, type_ = _resolve_optional(type_) return type_ is Any or is_primitive_type(type_) or is_structured_config(type_) diff --git a/omegaconf/base.py b/omegaconf/base.py index 6c1db7351..2bbc50ffd 100644 --- a/omegaconf/base.py +++ b/omegaconf/base.py @@ -16,6 +16,8 @@ format_and_raise, get_value_kind, split_key, + type_str, + valid_value_annotation_type, ) from .errors import ( ConfigKeyError, @@ -73,7 +75,10 @@ def __post_init__(self) -> None: self.ref_type = Any assert self.key_type is Any or isinstance(self.key_type, type) if self.element_type is not None: - assert self.element_type is Any or isinstance(self.element_type, type) + if not valid_value_annotation_type(self.element_type): + raise ValidationError( + f"Unsupported value type: '{type_str(self.element_type, include_module_name=True)}'" + ) if self.flags is None: self.flags = {} diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index 78daec2c1..3db7a6c5a 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -14,6 +14,7 @@ _is_missing_literal, _is_missing_value, _is_none, + _is_union, _resolve_optional, get_ref_type, get_structured_config_data, @@ -32,6 +33,7 @@ ConfigTypeError, InterpolationResolutionError, MissingMandatoryValue, + OmegaConfBaseException, ReadonlyConfigError, ValidationError, ) @@ -99,6 +101,12 @@ def __getstate__(self) -> Dict[str, Any]: dict_copy["_metadata"].ref_type = List else: assert False + if sys.version_info < (3, 7): # pragma: no cover + element_type = self._metadata.element_type + if _is_union(element_type): + raise OmegaConfBaseException( + "Serializing structured configs with `Union` element type requires python >= 3.7" + ) return dict_copy # Support pickle @@ -342,13 +350,12 @@ def expand(node: Container) -> None: dest[key] = target_node dest_node = dest._get_node(key) - if ( - dest_node is None - and is_structured_config(dest._metadata.element_type) - and not missing_src_value - ): + is_optional, et = _resolve_optional(dest._metadata.element_type) + if dest_node is None and is_structured_config(et) and not missing_src_value: # merging into a new node. Use element_type as a base - dest[key] = DictConfig(content=dest._metadata.element_type, parent=dest) + dest[key] = DictConfig( + et, parent=dest, ref_type=et, is_optional=is_optional + ) dest_node = dest._get_node(key) if dest_node is not None: @@ -420,9 +427,9 @@ def _list_merge(dest: Any, src: Any) -> None: temp_target.__dict__["_metadata"] = copy.deepcopy( dest.__dict__["_metadata"] ) - et = dest._metadata.element_type + is_optional, et = _resolve_optional(dest._metadata.element_type) if is_structured_config(et): - prototype = OmegaConf.structured(et) + prototype = DictConfig(et, ref_type=et, is_optional=is_optional) for item in src._iter_ex(resolve=False): if isinstance(item, DictConfig): item = OmegaConf.merge(prototype, item) @@ -541,14 +548,14 @@ def _set_item_impl(self, key: Any, value: Any) -> None: ) def wrap(key: Any, val: Any) -> Node: - is_optional = True if not is_structured_config(val): - ref_type = self._metadata.element_type + is_optional, ref_type = _resolve_optional(self._metadata.element_type) else: target = self._get_node(key) if target is None: - if is_structured_config(val): - ref_type = self._metadata.element_type + is_optional, ref_type = _resolve_optional( + self._metadata.element_type + ) else: assert isinstance(target, Node) is_optional = target._is_optional() @@ -566,6 +573,7 @@ def assign(value_key: Any, val: ValueNode) -> None: v = val v._set_parent(self) v._set_key(value_key) + _update_types(v, self._metadata.element_type, None) self.__dict__["_content"][value_key] = v if input_node and target_node: diff --git a/omegaconf/dictconfig.py b/omegaconf/dictconfig.py index 41828c1bc..e2901fe7c 100644 --- a/omegaconf/dictconfig.py +++ b/omegaconf/dictconfig.py @@ -23,6 +23,7 @@ _is_missing_literal, _is_missing_value, _is_none, + _resolve_optional, _valid_dict_key_annotation_type, format_and_raise, get_structured_config_data, @@ -35,7 +36,6 @@ is_structured_config, is_structured_config_frozen, type_str, - valid_value_annotation_type, ) from .base import Container, ContainerMetadata, DictKeyType, Node from .basecontainer import BaseContainer @@ -85,10 +85,6 @@ def __init__( flags=flags, ), ) - if not valid_value_annotation_type( - element_type - ) and not is_structured_config(element_type): - raise ValidationError(f"Unsupported value type: {element_type}") if not _valid_dict_key_annotation_type(key_type): raise KeyValidationError(f"Unsupported key type {key_type}") @@ -240,25 +236,27 @@ def _validate_merge(self, value: Any) -> None: ) raise ValidationError(msg) - def _validate_non_optional(self, key: Any, value: Any) -> None: + def _validate_non_optional(self, key: Optional[DictKeyType], value: Any) -> None: if _is_none(value, resolve=True, throw_on_resolution_failure=False): + if key is not None: child = self._get_node(key) if child is not None: assert isinstance(child, Node) - if not child._is_optional(): - self._format_and_raise( - key=key, - value=value, - cause=ValidationError("child '$FULL_KEY' is not Optional"), - ) - else: - if not self._is_optional(): - self._format_and_raise( - key=None, - value=value, - cause=ValidationError("field '$FULL_KEY' is not Optional"), + field_is_optional = child._is_optional() + else: + field_is_optional, _ = _resolve_optional( + self._metadata.element_type ) + else: + field_is_optional = self._is_optional() + + if not field_is_optional: + self._format_and_raise( + key=key, + value=value, + cause=ValidationError("field '$FULL_KEY' is not Optional"), + ) def _raise_invalid_value( self, value: Any, value_type: Any, target_type: Any diff --git a/omegaconf/listconfig.py b/omegaconf/listconfig.py index 9d81aec3a..f98e8366f 100644 --- a/omegaconf/listconfig.py +++ b/omegaconf/listconfig.py @@ -18,14 +18,13 @@ ValueKind, _is_missing_literal, _is_none, - _is_optional, + _resolve_optional, format_and_raise, get_value_kind, is_int, is_primitive_list, is_structured_config, type_str, - valid_value_annotation_type, ) from .base import Container, ContainerMetadata, Node from .basecontainer import BaseContainer @@ -70,10 +69,6 @@ def __init__( flags=flags, ), ) - if not (valid_value_annotation_type(self._metadata.element_type)): - raise ValidationError( - f"Unsupported value type: {self._metadata.element_type}" - ) self.__dict__["_content"] = None self._set_value(value=content, flags=flags) @@ -103,19 +98,19 @@ def _validate_set(self, key: Any, value: Any) -> None: "$FULL_KEY is not optional and cannot be assigned None" ) - target_type = self._metadata.element_type + is_optional, target_type = _resolve_optional(self._metadata.element_type) value_type = OmegaConf.get_type(value) - if is_structured_config(target_type): - if ( - target_type is not None - and value_type is not None - and not issubclass(value_type, target_type) - ): - msg = ( - f"Invalid type assigned: {type_str(value_type)} is not a " - f"subclass of {type_str(target_type)}. value: {value}" - ) - raise ValidationError(msg) + + if (value_type is None and not is_optional) or ( + is_structured_config(target_type) + and value_type is not None + and not issubclass(value_type, target_type) + ): + msg = ( + f"Invalid type assigned: {type_str(value_type)} is not a " + f"subclass of {type_str(target_type)}. value: {value}" + ) + raise ValidationError(msg) def __deepcopy__(self, memo: Dict[int, Any]) -> "ListConfig": res = ListConfig(None) @@ -281,11 +276,12 @@ def insert(self, index: int, item: Any) -> None: assert isinstance(self.__dict__["_content"], list) # insert place holder self.__dict__["_content"].insert(index, None) + is_optional, ref_type = _resolve_optional(self._metadata.element_type) node = _maybe_wrap( - ref_type=self.__dict__["_metadata"].element_type, + ref_type=ref_type, key=index, value=item, - is_optional=_is_optional(item), + is_optional=is_optional, parent=self, ) self._validate_set(key=index, value=node) diff --git a/tests/__init__.py b/tests/__init__.py index 9e91c0301..f433d703b 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -77,6 +77,11 @@ class Users: name2user: Dict[str, User] = field(default_factory=dict) +@dataclass +class OptionalUsers: + name2user: Dict[str, Optional[User]] = field(default_factory=dict) + + @dataclass class ConfWithMissingDict: dict: Dict[str, Any] = MISSING @@ -215,6 +220,12 @@ class SubscriptedList: list: List[int] = field(default_factory=lambda: [1, 2]) +@dataclass +class SubscriptedListOpt: + opt_list: Optional[List[int]] = field(default_factory=lambda: [1, 2]) + list_opt: List[Optional[int]] = field(default_factory=lambda: [1, 2, None]) + + @dataclass class UntypedDict: dict: Dict = field(default_factory=lambda: {"foo": "var"}) # type: ignore @@ -230,6 +241,14 @@ class SubscriptedDict: dict_bool: Dict[bool, int] = field(default_factory=lambda: {True: 4, False: 5}) +@dataclass +class SubscriptedDictOpt: + opt_dict: Optional[Dict[str, int]] = field(default_factory=lambda: {"foo": 4}) + dict_opt: Dict[str, Optional[int]] = field( + default_factory=lambda: {"foo": 4, "bar": None} + ) + + @dataclass class InterpolationList: list: List[float] = II("optimization.lr") diff --git a/tests/structured_conf/test_structured_basic.py b/tests/structured_conf/test_structured_basic.py index 7fe3b3887..5d0d27af2 100644 --- a/tests/structured_conf/test_structured_basic.py +++ b/tests/structured_conf/test_structured_basic.py @@ -106,7 +106,7 @@ def test_merge_error_override_bad_type(self, module: Any) -> None: def test_error_message(self, module: Any) -> None: cfg = OmegaConf.structured(module.StructuredOptional) - msg = re.escape("child 'not_optional' is not Optional") + msg = re.escape("field 'not_optional' is not Optional") with raises(ValidationError, match=msg): cfg.not_optional = None diff --git a/tests/structured_conf/test_structured_config.py b/tests/structured_conf/test_structured_config.py index 2d14570e6..2a7999dfd 100644 --- a/tests/structured_conf/test_structured_config.py +++ b/tests/structured_conf/test_structured_config.py @@ -808,10 +808,7 @@ def test_set_list_correct_type(self, module: Any) -> None: 3.1415, ["foo", True, 1.2], User(), - param( - [None], - marks=mark.xfail, # https://github.com/omry/omegaconf/issues/579 - ), + param([None]), ], ) def test_assign_wrong_type_to_list(self, module: Any, value: Any) -> None: @@ -825,10 +822,7 @@ def test_assign_wrong_type_to_list(self, module: Any, value: Any) -> None: @mark.parametrize( "value", [ - param( - None, - marks=mark.xfail, # https://github.com/omry/omegaconf/issues/579 - ), + param(None), True, "str", 3.1415, @@ -849,10 +843,7 @@ def test_insert_wrong_type_to_list(self, module: Any, value: Any) -> None: 3.1415, ["foo", True, 1.2], {"foo": True}, - param( - {"foo": None}, - marks=mark.xfail, # expected failure, https://github.com/omry/omegaconf/issues/579 - ), + param({"foo": None}), User(age=1, name="foo"), {"user": User(age=1, name="foo")}, ListConfig(content=[1, 2], ref_type=List[int], element_type=int), diff --git a/tests/test_base_config.py b/tests/test_base_config.py index f5d9e338e..3fc9c9bd3 100644 --- a/tests/test_base_config.py +++ b/tests/test_base_config.py @@ -1,5 +1,5 @@ import copy -from typing import Any, Dict, Union +from typing import Any, Dict, List, Optional, Union from pytest import mark, param, raises @@ -7,6 +7,7 @@ AnyNode, Container, DictConfig, + DictKeyType, IntegerNode, ListConfig, OmegaConf, @@ -17,9 +18,20 @@ open_dict, read_write, ) -from omegaconf._utils import nullcontext +from omegaconf._utils import _ensure_container, nullcontext from omegaconf.errors import ConfigAttributeError, ConfigKeyError, MissingMandatoryValue -from tests import StructuredWithMissing +from tests import ( + ConcretePlugin, + Group, + OptionalUsers, + StructuredWithMissing, + SubscriptedDict, + SubscriptedDictOpt, + SubscriptedList, + SubscriptedListOpt, + User, + Users, +) @mark.parametrize( @@ -587,3 +599,135 @@ def test_flags_root() -> None: cfg.a._set_flags_root(True) assert cfg.a._get_flag_no_cache("flag") is None + + +@mark.parametrize( + "cls,key,assignment,error", + [ + param(SubscriptedList, "list", [None], True, id="list_elt"), + param(SubscriptedList, "list", [0, 1, None], True, id="list_elt_partial"), + param(SubscriptedDict, "dict_str", {"key": None}, True, id="dict_elt"), + param( + SubscriptedDict, + "dict_str", + {"key_valid": 123, "key_invalid": None}, + True, + id="dict_elt_partial", + ), + param(SubscriptedList, "list", None, True, id="list"), + param(SubscriptedDict, "dict_str", None, True, id="dict"), + param(SubscriptedListOpt, "opt_list", [None], True, id="opt_list_elt"), + param(SubscriptedDictOpt, "opt_dict", {"key": None}, True, id="opt_dict_elt"), + param(SubscriptedListOpt, "opt_list", None, False, id="opt_list"), + param(SubscriptedDictOpt, "opt_dict", None, False, id="opt_dict"), + param(SubscriptedListOpt, "list_opt", [None], False, id="list_opt_elt"), + param(SubscriptedDictOpt, "dict_opt", {"key": None}, False, id="dict_opt_elt"), + param(SubscriptedListOpt, "list_opt", None, True, id="list_opt"), + param(SubscriptedDictOpt, "dict_opt", None, True, id="dict_opt"), + param( + ListConfig([None], element_type=Optional[User]), + 0, + User("Bond", 7), + False, + id="set_optional_user", + ), + param( + ListConfig([User], element_type=User), + 0, + None, + True, + id="illegal_set_user_to_none", + ), + ], +) +def test_optional_assign(cls: Any, key: str, assignment: Any, error: bool) -> None: + cfg = OmegaConf.structured(cls) + if error: + with raises(ValidationError): + cfg[key] = assignment + else: + cfg[key] = assignment + assert cfg[key] == assignment + + +@mark.parametrize( + "src,keys,ref_type,is_optional", + [ + param(Group, ["admin"], User, True, id="opt_user"), + param( + ConcretePlugin, + ["params"], + ConcretePlugin.FoobarParams, + False, + id="nested_structured_conf", + ), + param( + OmegaConf.structured(Users({"user007": User("Bond", 7)})).name2user, + ["user007"], + User, + False, + id="structured_dict_of_user", + ), + param( + DictConfig({"a": 123}, element_type=int), ["a"], int, False, id="dict_int" + ), + param( + DictConfig({"a": 123}, element_type=Optional[int]), + ["a"], + int, + True, + id="dict_opt_int", + ), + param(DictConfig({"a": 123}), ["a"], Any, True, id="dict_any"), + param( + OmegaConf.merge(Users, {"name2user": {"joe": User("joe")}}), + ["name2user", "joe"], + User, + False, + id="dict:merge_into_new_user_node", + ), + param( + OmegaConf.merge(OptionalUsers, {"name2user": {"joe": User("joe")}}), + ["name2user", "joe"], + User, + True, + id="dict:merge_into_new_optional_user_node", + ), + param( + OmegaConf.merge(ListConfig([], element_type=User), [User(name="joe")]), + [0], + User, + False, + id="list:merge_into_new_user_node", + ), + param( + OmegaConf.merge( + ListConfig([], element_type=Optional[User]), [User(name="joe")] + ), + [0], + User, + True, + id="list:merge_into_new_optional_user_node", + ), + param(SubscriptedDictOpt, ["opt_dict"], Dict[str, int], True, id="opt_dict"), + param(SubscriptedListOpt, ["opt_list"], List[int], True, id="opt_list"), + param( + SubscriptedDictOpt, + ["dict_opt"], + Dict[str, Optional[int]], + False, + id="opt_dict", + ), + param( + SubscriptedListOpt, ["list_opt"], List[Optional[int]], False, id="opt_dict" + ), + ], +) +def test_assignment_optional_behavior( + src: Any, keys: List[DictKeyType], ref_type: Any, is_optional: bool +) -> None: + cfg = _ensure_container(src) + for k in keys: + cfg = cfg._get_node(k) + assert cfg._is_optional() == is_optional + assert cfg._metadata.ref_type == ref_type diff --git a/tests/test_basic_ops_list.py b/tests/test_basic_ops_list.py index 950ce7253..5bdb57607 100644 --- a/tests/test_basic_ops_list.py +++ b/tests/test_basic_ops_list.py @@ -6,7 +6,8 @@ from pytest import mark, param, raises from omegaconf import MISSING, AnyNode, DictConfig, ListConfig, OmegaConf, flag_override -from omegaconf._utils import nullcontext +from omegaconf._utils import _ensure_container, nullcontext +from omegaconf.base import Node from omegaconf.errors import ( ConfigTypeError, InterpolationKeyError, @@ -446,6 +447,86 @@ def validate_list_keys(c: Any) -> None: assert c._get_node(i)._metadata.key == i +@mark.parametrize( + "cfg, value, expected, expected_ref_type", + [ + param( + ListConfig(element_type=int, content=[]), + 123, + [123], + int, + id="typed_list", + ), + param( + ListConfig(element_type=int, content=[]), + None, + ValidationError, + None, + id="typed_list_append_none", + ), + param( + ListConfig(element_type=Optional[int], content=[]), + 123, + [123], + int, + id="optional_typed_list", + ), + param( + ListConfig(element_type=Optional[int], content=[]), + None, + [None], + int, + id="optional_typed_list_append_none", + ), + param( + ListConfig(element_type=User, content=[]), + User(name="bond"), + [User(name="bond")], + User, + id="user_list", + ), + param( + ListConfig(element_type=User, content=[]), + None, + ValidationError, + None, + id="user_list_append_none", + ), + param( + ListConfig(element_type=Optional[User], content=[]), + User(name="bond"), + [User(name="bond")], + User, + id="optional_user_list", + ), + param( + ListConfig(element_type=Optional[User], content=[]), + None, + [None], + User, + id="optional_user_list_append_none", + ), + ], +) +def test_append_to_typed( + cfg: ListConfig, + value: Any, + expected: Any, + expected_ref_type: type, +) -> None: + cfg = _ensure_container(cfg) + if isinstance(expected, type): + with raises(expected): + cfg.append(value) + else: + cfg.append(value) + assert cfg == expected + node = cfg._get_node(-1) + assert isinstance(node, Node) + assert node._metadata.ref_type == expected_ref_type + validate_list_keys(cfg) + + @mark.parametrize( "input_, index, value, expected, expected_node_type, expectation", [ @@ -475,6 +556,24 @@ def validate_list_keys(c: Any) -> None: None, ValidationError, ), + param( + ListConfig(element_type=int, content=[]), + 0, + 123, + [123], + IntegerNode, + None, + id="typed_list", + ), + param( + ListConfig(element_type=int, content=[]), + 0, + None, + None, + None, + ValidationError, + id="typed_list_insert_none", + ), ], ) def test_insert( @@ -786,3 +885,37 @@ def test_node_copy_on_append(node: Any) -> None: cfg = OmegaConf.create([]) cfg.append(node) assert cfg.__dict__["_content"][0] is not node + + +@mark.parametrize( + "cfg,key,value,error", + [ + param( + ListConfig([], element_type=Optional[User]), + 0, + "foo", + True, + id="structured:set_optional_to_bad_type", + ), + param( + ListConfig([], element_type=int), + 0, + None, + True, + id="set_to_none_raises", + ), + param( + ListConfig([], element_type=Optional[int]), + 0, + None, + False, + id="optional_set_to_none", + ), + ], +) +def test_validate_set(cfg: ListConfig, key: int, value: Any, error: bool) -> None: + if error: + with raises(ValidationError): + cfg._validate_set(key, value) + else: + cfg._validate_set(key, value) diff --git a/tests/test_errors.py b/tests/test_errors.py index f0bc6158e..e2ac89a4e 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -131,7 +131,7 @@ def finalize(self, cfg: Any) -> None: create=lambda: OmegaConf.structured(StructuredWithMissing), op=lambda cfg: OmegaConf.update(cfg, "num", None), exception_type=ValidationError, - msg="child 'num' is not Optional", + msg="field 'num' is not Optional", parent_node=lambda cfg: cfg, child_node=lambda cfg: cfg._get_node("num"), object_type=StructuredWithMissing, @@ -370,7 +370,7 @@ def finalize(self, cfg: Any) -> None: ), op=lambda cfg: setattr(cfg, "foo", None), exception_type=ValidationError, - msg="child 'foo' is not Optional", + msg="field 'foo' is not Optional", key="foo", full_key="foo", child_node=lambda cfg: cfg.foo, @@ -648,6 +648,16 @@ def finalize(self, cfg: Any) -> None: ), id="dict:create:not_optional_A_field_with_none", ), + param( + Expected( + create=lambda: DictConfig({}, element_type=str), + op=lambda cfg: OmegaConf.merge(cfg, {"foo": None}), + exception_type=ValidationError, + key="foo", + msg="field 'foo' is not Optional", + ), + id="dict:merge_none_into_not_optional_element_type", + ), param( Expected( create=lambda: None, @@ -709,6 +719,15 @@ def finalize(self, cfg: Any) -> None: ), id="structured:create_from_unsupported_object", ), + param( + Expected( + create=lambda: None, + op=lambda _: DictConfig({}, element_type=IllegalType), + exception_type=ValidationError, + msg="Unsupported value type: 'tests.IllegalType'", + ), + id="structured:create_with_unsupported_element_type", + ), param( Expected( create=lambda: None, diff --git a/tests/test_matrix.py b/tests/test_matrix.py index 5cce208a1..d26ae740e 100644 --- a/tests/test_matrix.py +++ b/tests/test_matrix.py @@ -111,7 +111,7 @@ def test_none_assignment_and_merging_in_dict( data = {"node": node} cfg = OmegaConf.create(obj=data) verify(cfg, "node", none=False, opt=False, missing=False, inter=False) - msg = "child 'node' is not Optional" + msg = "field 'node' is not Optional" with raises(ValidationError, match=re.escape(msg)): cfg.node = None diff --git a/tests/test_merge.py b/tests/test_merge.py index c982bbc91..e28b83cff 100644 --- a/tests/test_merge.py +++ b/tests/test_merge.py @@ -1,7 +1,19 @@ import copy +import re import sys -from typing import Any, Dict, List, MutableMapping, MutableSequence, Tuple, Union +from textwrap import dedent +from typing import ( + Any, + Dict, + List, + MutableMapping, + MutableSequence, + Optional, + Tuple, + Union, +) +from _pytest.python_api import RaisesContext from pytest import mark, param, raises from omegaconf import ( @@ -12,7 +24,8 @@ ReadonlyConfigError, ValidationError, ) -from omegaconf._utils import is_structured_config +from omegaconf._utils import _ensure_container, is_structured_config +from omegaconf.base import Node from omegaconf.errors import ConfigKeyError, UnsupportedValueType from omegaconf.nodes import IntegerNode from tests import ( @@ -28,6 +41,7 @@ InterpolationList, MissingDict, MissingList, + OptionalUsers, OptTuple, Package, Plugin, @@ -218,6 +232,16 @@ {"name2user": {"joe": {"name": "joe", "age": MISSING}}}, id="users_merge_with_missing_age", ), + param( + [OptionalUsers, {"name2user": {"joe": {"name": "joe"}}}], + {"name2user": {"joe": {"name": "joe", "age": MISSING}}}, + id="optionalusers_merge_with_missing_age", + ), + param( + [OptionalUsers, {"name2user": {"joe": None}}], + {"name2user": {"joe": None}}, + id="optionalusers_merge_with_none", + ), param( [ConfWithMissingDict, {"dict": {"foo": "bar"}}], {"dict": {"foo": "bar"}}, @@ -283,7 +307,7 @@ ), param( ( - DictConfig({"user007": None}, element_type=User), + DictConfig({"user007": None}, element_type=Optional[User]), {"user007": {"age": 99}}, ), {"user007": {"name": "???", "age": 99}}, @@ -365,6 +389,373 @@ def test_merge( merge_function(*configs) +@mark.parametrize( + "inputs,expected,ref_type,is_optional", + [ + param( + (DictConfig(content={"foo": "bar"}, element_type=str), {"foo": "qux"}), + {"foo": "qux"}, + str, + False, + id="str", + ), + param( + (DictConfig(content={"foo": "bar"}, element_type=str), {"foo": None}), + raises( + ValidationError, + match="Incompatible value 'None' for field of type 'str'", + ), + None, + None, + id="str_none", + ), + param( + (DictConfig(content={"foo": "bar"}, element_type=str), {"foo": MISSING}), + {"foo": "bar"}, + str, + False, + id="str_missing", + ), + param( + ( + DictConfig(content={"foo": "bar"}, element_type=Optional[str]), + {"foo": "qux"}, + ), + {"foo": "qux"}, + str, + True, + id="optional_str", + ), + param( + ( + DictConfig(content={"foo": "bar"}, element_type=Optional[str]), + {"foo": None}, + ), + {"foo": None}, + str, + True, + id="optional_str_none", + ), + param( + ( + DictConfig(content={"foo": "bar"}, element_type=Optional[str]), + {"foo": MISSING}, + ), + {"foo": "bar"}, + str, + True, + id="optional_str_missing", + ), + param( + (DictConfig(content={}, element_type=str), {"foo": "qux"}), + {"foo": "qux"}, + str, + False, + id="new_str", + ), + param( + (DictConfig(content={}, element_type=str), {"foo": None}), + raises( + ValidationError, + match="field 'foo' is not Optional", + ), + None, + None, + id="new_str_none", + ), + param( + (DictConfig(content={}, element_type=str), {"foo": MISSING}), + {"foo": MISSING}, + str, + False, + id="new_str_missing", + ), + param( + (DictConfig(content={}, element_type=Optional[str]), {"foo": "qux"}), + {"foo": "qux"}, + str, + True, + id="new_optional_str", + ), + param( + (DictConfig(content={}, element_type=Optional[str]), {"foo": None}), + {"foo": None}, + str, + True, + id="new_optional_str_none", + ), + param( + (DictConfig(content={}, element_type=Optional[str]), {"foo": MISSING}), + {"foo": MISSING}, + str, + True, + id="new_optional_str_missing", + ), + param( + (DictConfig(content={"foo": MISSING}, element_type=str), {"foo": "qux"}), + {"foo": "qux"}, + str, + False, + id="missing_str", + ), + param( + (DictConfig(content={"foo": MISSING}, element_type=str), {"foo": None}), + raises( + ValidationError, + match="Incompatible value 'None' for field of type 'str'", + ), + None, + None, + id="missing_str_none", + ), + param( + (DictConfig(content={"foo": MISSING}, element_type=str), {"foo": MISSING}), + {"foo": MISSING}, + str, + False, + id="missing_str_missing", + ), + param( + ( + DictConfig(content={"foo": MISSING}, element_type=Optional[str]), + {"foo": "qux"}, + ), + {"foo": "qux"}, + str, + True, + id="missing_optional_str", + ), + param( + ( + DictConfig(content={"foo": MISSING}, element_type=Optional[str]), + {"foo": None}, + ), + {"foo": None}, + str, + True, + id="missing_optional_str_none", + ), + param( + ( + DictConfig(content={"foo": MISSING}, element_type=Optional[str]), + {"foo": MISSING}, + ), + {"foo": MISSING}, + str, + True, + id="missing_optional_str_missing", + ), + param( + ( + DictConfig(content={"foo": User("Bond")}, element_type=User), + {"foo": User("007")}, + ), + {"foo": User("007")}, + User, + False, + id="user", + ), + param( + ( + DictConfig(content={"foo": User("Bond")}, element_type=User), + {"foo": None}, + ), + raises( + ValidationError, + match=re.escape( + dedent( + """\ + field 'foo' is not Optional + full_key: foo + reference_type=User + object_type=User""" + ) + ), + ), + None, + None, + id="user_none", + ), + param( + ( + DictConfig(content={"foo": User("Bond")}, element_type=User), + {"foo": MISSING}, + ), + {"foo": User("Bond")}, + User, + False, + id="user_missing", + ), + param( + ( + DictConfig(content={"foo": User("Bond")}, element_type=Optional[User]), + {"foo": User("007")}, + ), + {"foo": User("007")}, + User, + True, + id="optional_user", + ), + param( + ( + DictConfig(content={"foo": User("Bond")}, element_type=Optional[User]), + {"foo": None}, + ), + {"foo": None}, + User, + True, + id="optional_user_none", + ), + param( + ( + DictConfig(content={"foo": User("Bond")}, element_type=Optional[User]), + {"foo": MISSING}, + ), + {"foo": User("Bond")}, + User, + True, + id="optional_user_missing", + ), + param( + (DictConfig(content={}, element_type=User), {"foo": User("Bond")}), + {"foo": User("Bond")}, + User, + False, + id="new_user", + ), + param( + (DictConfig(content={}, element_type=User), {"foo": None}), + raises( + ValidationError, + match=re.escape( + dedent( + """\ + field 'foo' is not Optional + full_key: foo + object_type=dict""" + ) + ), + ), + None, + None, + id="new_user_none", + ), + param( + (DictConfig(content={}, element_type=User), {"foo": MISSING}), + {"foo": MISSING}, + User, + False, + id="new_user_missing", + ), + param( + ( + DictConfig(content={}, element_type=Optional[User]), + {"foo": User("Bond")}, + ), + {"foo": User("Bond")}, + User, + True, + id="new_optional_user", + ), + param( + (DictConfig(content={}, element_type=Optional[User]), {"foo": None}), + {"foo": None}, + User, + True, + id="new_optional_user_none", + ), + param( + (DictConfig(content={}, element_type=Optional[User]), {"foo": MISSING}), + {"foo": MISSING}, + User, + True, + id="new_optional_user_missing", + ), + param( + ( + DictConfig(content={"foo": MISSING}, element_type=User), + {"foo": User("Bond")}, + ), + {"foo": User("Bond")}, + User, + False, + id="missing_user", + ), + param( + (DictConfig(content={"foo": MISSING}, element_type=User), {"foo": None}), + raises( + ValidationError, + match=re.escape( + dedent( + """\ + field 'foo' is not Optional + full_key: foo + reference_type=User + object_type=NoneType""" + ) + ), + ), + None, + None, + id="missing_user_none", + ), + param( + (DictConfig(content={"foo": MISSING}, element_type=User), {"foo": MISSING}), + {"foo": MISSING}, + User, + False, + id="missing_user_missing", + ), + param( + ( + DictConfig(content={"foo": MISSING}, element_type=Optional[User]), + {"foo": User("Bond")}, + ), + {"foo": User("Bond")}, + User, + True, + id="missing_optional_user", + ), + param( + ( + DictConfig(content={"foo": MISSING}, element_type=Optional[User]), + {"foo": None}, + ), + {"foo": None}, + User, + True, + id="missing_optional_user_none", + ), + param( + ( + DictConfig(content={"foo": MISSING}, element_type=Optional[User]), + {"foo": MISSING}, + ), + {"foo": MISSING}, + User, + True, + id="missing_optional_user_missing", + ), + ], +) +def test_optional_element_type_merge( + inputs: Any, expected: Any, ref_type: Any, is_optional: bool +) -> None: + configs = [_ensure_container(c) for c in inputs] + if isinstance(expected, RaisesContext): + with expected: + OmegaConf.merge(*configs) + else: + cfg = OmegaConf.merge(*configs) + assert cfg == expected + + assert isinstance(cfg, DictConfig) + node = cfg._get_node("foo") + assert isinstance(node, Node) + assert node._is_optional() == is_optional + assert node._metadata.ref_type == ref_type + + def test_merge_error_retains_type() -> None: cfg = OmegaConf.structured(ConcretePlugin) with raises(ValidationError): diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 522b37c4f..5c811fa5d 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -3,6 +3,8 @@ import os import pathlib import pickle +import re +import sys import tempfile from pathlib import Path from typing import Any, Dict, List, Optional, Type @@ -11,12 +13,15 @@ from omegaconf import MISSING, DictConfig, ListConfig, OmegaConf from omegaconf._utils import get_ref_type +from omegaconf.errors import OmegaConfBaseException from tests import ( Color, PersonA, PersonD, SubscriptedDict, + SubscriptedDictOpt, SubscriptedList, + SubscriptedListOpt, UntypedDict, UntypedList, ) @@ -153,24 +158,101 @@ def test_load_empty_file(tmpdir: str) -> None: @mark.parametrize( "input_,node,element_type,key_type,optional,ref_type", [ - (UntypedList, "list", Any, Any, False, List[Any]), - (UntypedList, "opt_list", Any, Any, True, Optional[List[Any]]), - (UntypedDict, "dict", Any, Any, False, Dict[Any, Any]), - ( + param(UntypedList, "list", Any, Any, False, List[Any], id="list_untyped"), + param( + UntypedList, + "opt_list", + Any, + Any, + True, + Optional[List[Any]], + id="opt_list_untyped", + ), + param(UntypedDict, "dict", Any, Any, False, Dict[Any, Any], id="dict_untyped"), + param( UntypedDict, "opt_dict", Any, Any, True, Optional[Dict[Any, Any]], + id="opt_dict_untyped", ), - (SubscriptedDict, "dict_str", int, str, False, Dict[str, int]), - (SubscriptedDict, "dict_int", int, int, False, Dict[int, int]), - (SubscriptedDict, "dict_bool", int, bool, False, Dict[bool, int]), - (SubscriptedDict, "dict_float", int, float, False, Dict[float, int]), - (SubscriptedDict, "dict_enum", int, Color, False, Dict[Color, int]), - (SubscriptedList, "list", int, Any, False, List[int]), - ( + param( + SubscriptedDict, "dict_str", int, str, False, Dict[str, int], id="dict_str" + ), + param( + SubscriptedDict, "dict_int", int, int, False, Dict[int, int], id="dict_int" + ), + param( + SubscriptedDict, + "dict_bool", + int, + bool, + False, + Dict[bool, int], + id="dict_bool", + ), + param( + SubscriptedDict, + "dict_float", + int, + float, + False, + Dict[float, int], + id="dict_float", + ), + param( + SubscriptedDict, + "dict_enum", + int, + Color, + False, + Dict[Color, int], + id="dict_enum", + ), + param(SubscriptedList, "list", int, Any, False, List[int], id="list_int"), + param( + SubscriptedDictOpt, + "opt_dict", + int, + str, + True, + Optional[Dict[str, int]], + marks=mark.skipif(sys.version_info < (3, 7), reason="requires Python 3.7"), + id="opt_dict", + ), + param( + SubscriptedDictOpt, + "dict_opt", + Optional[int], + str, + False, + Dict[str, Optional[int]], + marks=mark.skipif(sys.version_info < (3, 7), reason="requires Python 3.7"), + id="dict_opt", + ), + param( + SubscriptedListOpt, + "opt_list", + int, + str, + True, + Optional[List[int]], + marks=mark.skipif(sys.version_info < (3, 7), reason="requires Python 3.7"), + id="opt_list", + ), + param( + SubscriptedListOpt, + "list_opt", + Optional[int], + str, + False, + List[Optional[int]], + marks=mark.skipif(sys.version_info < (3, 7), reason="requires Python 3.7"), + id="list_opt", + ), + param( DictConfig( content={"a": "foo"}, ref_type=Dict[str, str], @@ -271,3 +353,15 @@ def test_pickle_backward_compatibility(version: str) -> None: with open(path, mode="rb") as fp: cfg = pickle.load(fp) assert cfg == OmegaConf.create({"a": [{"b": 10}]}) + + +@mark.skipif(sys.version_info >= (3, 7), reason="requires python3.6") +def test_python36_pickle_optional() -> None: + cfg = OmegaConf.structured(SubscriptedDictOpt) + with raises( + OmegaConfBaseException, + match=re.escape( + "Serializing structured configs with `Union` element type requires python >= 3.7" + ), + ): + pickle.dumps(cfg)