Skip to content

Commit

Permalink
Merge pull request #588 from mit-ll-responsible-ai/auto-config-fix
Browse files Browse the repository at this point in the history
fix store autoconfig issue
  • Loading branch information
rsokl authored Nov 15, 2023
2 parents e31c388 + 3985e35 commit 16825fc
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 21 deletions.
17 changes: 12 additions & 5 deletions src/hydra_zen/structured_configs/_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,10 @@ def _check_instance(*target_types: str, value: "Any", module: str): # pragma: n
_check_instance, "required", module="torch.optim.optimizer"
)

_is_pydantic_BaseModel = functools.partial(
_check_instance, "BaseModel", module="pydantic"
)


def _check_for_dynamically_defined_dataclass_type(target_path: str, value: Any) -> None:
if target_path.startswith("types."):
Expand Down Expand Up @@ -1096,13 +1100,16 @@ def _make_hydra_compatible(
pydantic = sys.modules.get("pydantic")

if pydantic is not None: # pragma: no cover
if isinstance(value, pydantic.fields.FieldInfo):
if _check_instance("FieldInfo", module="pydantic.fields", value=value):
_val = (
value.default_factory() # type: ignore
if value.default_factory is not None # type: ignore
else value.default # type: ignore
)
if isinstance(_val, pydantic.fields.UndefinedType):

if _check_instance(
"UndefinedType", module="pydantic.fields", value=_val
):
return MISSING

return cls._make_hydra_compatible(
Expand All @@ -1115,11 +1122,11 @@ def _make_hydra_compatible(
hydra_convert=hydra_convert,
hydra_recursive=hydra_recursive,
)
if isinstance(value, pydantic.BaseModel):
if _is_pydantic_BaseModel(value=value):
return cls.builds(type(value), **value.__dict__)

if isinstance(value, str) or (
pydantic is not None and isinstance(value, pydantic.AnyUrl)
if isinstance(value, str) or _check_instance(
"AnyUrl", module="pydantic", value=value
):
# Supports pydantic.AnyURL
_v = str(value)
Expand Down
32 changes: 17 additions & 15 deletions src/hydra_zen/wrapper/_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from functools import partial, wraps
from inspect import Parameter, iscoroutinefunction, signature
from typing import (
TYPE_CHECKING,
Any,
Callable,
DefaultDict,
Expand Down Expand Up @@ -48,7 +49,7 @@
from hydra_zen import instantiate
from hydra_zen._compatibility import HYDRA_VERSION, Version
from hydra_zen.errors import HydraZenValidationError
from hydra_zen.structured_configs._implementations import BuildsFn, DefaultBuilds
from hydra_zen.structured_configs._implementations import DefaultBuilds
from hydra_zen.structured_configs._type_guards import safe_getattr
from hydra_zen.typing._implementations import (
DataClass_,
Expand All @@ -61,9 +62,12 @@
from ..structured_configs._type_guards import is_dataclass
from ..structured_configs._utils import safe_name

if TYPE_CHECKING:
from hydra_zen import BuildsFn


__all__ = ["zen", "store", "Zen"]

get_obj_path = BuildsFn._get_obj_path # type: ignore

R = TypeVar("R")
P = ParamSpec("P")
Expand Down Expand Up @@ -830,7 +834,7 @@ def default_to_config(
ListConfig,
DictConfig,
],
BuildsFn: Type[BuildsFn[Any]] = DefaultBuilds,
CustomBuildsFn: Type["BuildsFn[Any]"] = DefaultBuilds,
**kw: Any,
) -> Union[DataClass_, Type[DataClass_], ListConfig, DictConfig]:
"""Creates a config that describes `target`.
Expand All @@ -849,7 +853,7 @@ def default_to_config(
----------
target : Callable[..., Any] | DataClass | Type[DataClass] | list | dict
BuildsFn : Type[BuildsFn[Any]], optional (default=DefaultBuilds)
CustomBuildsFn : Type[BuildsFn[Any]], optional (default=DefaultBuilds)
Provides the config-creation functions (`builds`, `just`) used
by this function.
Expand Down Expand Up @@ -893,21 +897,20 @@ def default_to_config(
'y': ???
"""

kw = kw.copy()

if is_dataclass(target):
if isinstance(target, type):
if issubclass(target, HydraConf):
# don't auto-config HydraConf
return target

if not kw and get_obj_path(target).startswith("types."):
if not kw and CustomBuildsFn._get_obj_path(target).startswith("types."): # type: ignore
# handles dataclasses returned by make_config()
return target
return BuildsFn.builds(
target,
**kw,
populate_full_signature=True,
builds_bases=(target,),
)
kw.setdefault("populate_full_signature", True)
kw.setdefault("builds_bases", (target,))
return CustomBuildsFn.builds(target, **kw)
if kw:
raise ValueError(
"store(<dataclass-instance>, [...]) does not support specifying "
Expand All @@ -917,14 +920,13 @@ def default_to_config(

elif isinstance(target, (dict, list)):
# TODO: convert to OmegaConf containers?
return BuildsFn.just(target)
return CustomBuildsFn.just(target)
elif isinstance(target, (DictConfig, ListConfig)):
return target
else:
t = cast(Callable[..., Any], target)
return cast(
Type[DataClass_], BuildsFn.builds(t, **kw, populate_full_signature=True)
)
kw.setdefault("populate_full_signature", True)
return cast(Type[DataClass_], CustomBuildsFn.builds(t, **kw))


class _HasName(Protocol):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_BuildsFn.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def test_zen_field():

def test_default_to_config():
store = ZenStore("my store")(
to_config=partial(default_to_config, BuildsFn=MyBuildsFn)
to_config=partial(default_to_config, CustomBuildsFn=MyBuildsFn)
)
store(A, x=A(x=2), name="blah")
assert instantiate(store[None, "blah"]) == A(x=A(x=2))
17 changes: 17 additions & 0 deletions tests/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
just,
make_config,
store as default_store,
to_yaml,
)
from tests.custom_strategies import new_stores, store_entries

Expand Down Expand Up @@ -982,3 +983,19 @@ def test_merge():
assert s3._queue == {(None, "a"), (None, "b"), (None, "c")}
assert s3._internal_repo[None, "a"] is not s1._internal_repo[None, "a"]
assert s3._internal_repo[None, "b"] is not s2._internal_repo[None, "b"]


def test_disable_pop_sig_autoconfig():
s = ZenStore()
s(ZenStore, populate_full_signature=False, name="s")
config = s[None, "s"]
assert len(to_yaml(config).splitlines()) == 1
sout = instantiate(s[None, "s"])
assert isinstance(sout, ZenStore)

s2 = ZenStore()
s2(ZenStore, name="s")
config2 = s2[None, "s"]
assert len(to_yaml(config2).splitlines()) > 1
sout2 = instantiate(s2[None, "s"])
assert isinstance(sout2, ZenStore)
7 changes: 7 additions & 0 deletions tests/test_third_party/test_using_v1_pydantic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) 2023 Massachusetts Institute of Technology
# SPDX-License-Identifier: MIT
import dataclasses
import sys
from typing import Any, List, Optional

import hypothesis.strategies as st
Expand Down Expand Up @@ -28,6 +29,12 @@
)


def test_BaseModel():
_pydantic = sys.modules.get("pydantic")
assert _pydantic is not None
assert _pydantic.BaseModel is BaseModel


@parametrize_pydantic_fields
def test_pydantic_specific_fields_function(custom_type, good_val, bad_val):
def f(x):
Expand Down

0 comments on commit 16825fc

Please sign in to comment.