Skip to content

Commit

Permalink
Add support for builds(Dataclass, populate_full_signature=True) when …
Browse files Browse the repository at this point in the history
…dataclass has default factory fields (#299)

* add support for populating field with default factory

* improve tests and add regression for hydra 2350

* update changelog
  • Loading branch information
rsokl authored Aug 19, 2022
1 parent 4d90e31 commit 7675970
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 6 deletions.
5 changes: 3 additions & 2 deletions docs/source/changes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@ hydra-zen :ref:`already provides support for these <additional-types>`, but this

Improvements
------------
- The :ref:`automatic type refinement <type-support>` performed by :func:`~hydra_zen.builds` now has enhanced support for ``typing.Annotated``, ``typing.NewType``, and ``typing.TypeVarTuple``. (See :pull:`283`)
- :func:`~hydra_zen.builds` no longer has restrictions on inheritance patterns involving `PartialBuilds`-type configs. (See :pull:`290`)
- Two new utility functions were added to the public API: :func:`~hydra_zen.is_partial_builds` and :func:`~hydra_zen.uses_zen_processing`
- Improved :ref:`automatic type refinement <type-support>` for bare sequence types, and parity with support for `dict`, `list`, and `tuple` as type annotations in omegaconf 2.2.3+. (See :pull:`297`)
- Improved :ref:`automatic type refinement <type-support>` for bare sequence types, and adds conditional support for `dict`, `list`, and `tuple` as type annotations when omegaconf 2.2.3+ is installed. (See :pull:`297`)
- Adds support for using `builds(<target>, populate_full_signature=True)` where `<target>` is a dataclass type that has a field with a default factory. (See :pull:`299`)
- The :ref:`automatic type refinement <type-support>` performed by :func:`~hydra_zen.builds` now has enhanced support for ``typing.Annotated``, ``typing.NewType``, and ``typing.TypeVarTuple``. (See :pull:`283`)

Bug Fixes
---------
Expand Down
24 changes: 22 additions & 2 deletions src/hydra_zen/structured_configs/_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1367,15 +1367,15 @@ def builds(
# We want to rely on `inspect.signature` logic for raising
# against an uninspectable sig, before we start inspecting
# class-specific attributes below.
signature_params = inspect.signature(target).parameters
signature_params = dict(inspect.signature(target).parameters)
except ValueError:
if populate_full_signature:
raise ValueError(
BUILDS_ERROR_PREFIX
+ f"{target} does not have an inspectable signature. "
f"`builds({_utils.safe_name(target)}, populate_full_signature=True)` is not supported"
)
signature_params: Mapping[str, inspect.Parameter] = {}
signature_params: Dict[str, inspect.Parameter] = {}
# We will turn off signature validation for objects that didn't have
# a valid signature. This will enable us to do things like `build(dict, a=1)`
target_has_valid_signature: bool = False
Expand Down Expand Up @@ -1414,6 +1414,26 @@ def builds(

target_has_valid_signature: bool = True

if is_dataclass(target):
# Update `signature_params` so that any param with `default=<factory>`
# has its default replaced with `<factory>()`
# If this is a mutable value, `builds` will automatically re-pack
# it using a default factory
_fields = {f.name: f for f in fields(target)}
_update = {}
for name, param in signature_params.items():
f = _fields[name]
if f.default_factory is not MISSING:
_update[name] = inspect.Parameter(
name,
param.kind,
annotation=param.annotation,
default=f.default_factory(),
)
signature_params.update(_update)
del _update
del _fields

# `get_type_hints` properly resolves forward references, whereas annotations from
# `inspect.signature` do not
try:
Expand Down
102 changes: 100 additions & 2 deletions tests/test_value_conversion.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
# Copyright (c) 2022 Massachusetts Institute of Technology
# SPDX-License-Identifier: MIT
import inspect
import string
from collections import Counter, deque
from dataclasses import dataclass
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import Dict, FrozenSet, List, Set, Union
from typing import Any, Callable, Dict, FrozenSet, Iterable, List, Set, Union

import hypothesis.strategies as st
import pytest
Expand All @@ -23,6 +24,8 @@
)
from hydra_zen._compatibility import ZEN_SUPPORTED_PRIMITIVES
from hydra_zen.structured_configs._value_conversion import ZEN_VALUE_CONVERSION
from hydra_zen.typing import Partial
from tests import is_same


def test_supported_primitives_in_sync_with_value_conversion():
Expand Down Expand Up @@ -238,3 +241,98 @@ class C:
with pytest.raises((ValidationError, AssertionError)):
Conf = OmegaConf.create(C)
assert isinstance(Conf.x, type_)


@dataclass
class A_builds_populate_sig_with_default_factory:
z: Any
x_list: List[int] = field(default_factory=lambda: list([1, 0, 1, 0, 1]))
x_dict: Dict[str, int] = field(default_factory=lambda: dict({"K_DEFAULT": 10101}))
y: bool = False


@given(
via_yaml=st.booleans(),
list_=st.none() | st.lists(st.integers()),
# TODO: generalize to st.dictionaries(st.sampled_from("abcd"), st.integers())
# once https://github.com/facebookresearch/hydra/issues/2350 is resolved
dict_=st.none(),
kwargs_via_builds=st.booleans(),
)
def test_builds_populate_sig_with_default_factory(
via_yaml: bool, list_, dict_, kwargs_via_builds
):
A = A_builds_populate_sig_with_default_factory
kwargs = {}
if list_ is not None:
kwargs["x_list"] = list_

if dict_ is not None:
kwargs["x_dict"] = dict_

Conf = (
builds(A, **kwargs, populate_full_signature=True)
if kwargs_via_builds
else builds(A, populate_full_signature=True)
)

assert inspect.signature(Conf).parameters == inspect.signature(A).parameters

if via_yaml:
Conf = OmegaConf.structured(to_yaml(Conf))
a_expected = A(z=1, **kwargs)

a = (
instantiate(Conf, z=1)
if kwargs_via_builds
else instantiate(Conf, **kwargs, z=1)
)

assert isinstance(a, A)
assert a.x_list == a_expected.x_list
assert a.x_dict == a_expected.x_dict
assert a.y == a_expected.y
assert a.z == a_expected.z


@dataclass
class A_auto_config_for_dataclass_fields:
complex_factory: Any = mutable_value(1 + 2j)
complex_: complex = 2 + 4j
list_of_stuff: List[Any] = field(
default_factory=lambda: list([1 + 2j, Path.home()])
)
fn_factory: Callable[[Iterable[int]], int] = field(default_factory=lambda: sum)
fn: Callable[[Iterable[int]], int] = sum
partial_factory: Partial[int] = field(default_factory=lambda: partial(sum))
partial_: Partial[int] = partial(sum)


def test_auto_config_for_dataclass_fields():
A = A_auto_config_for_dataclass_fields
Conf = builds(A, populate_full_signature=True)
actual = instantiate(Conf)
expected = A()
assert isinstance(actual, A)
assert actual.complex_factory == expected.complex_factory
assert actual.complex_ == expected.complex_
assert actual.list_of_stuff == expected.list_of_stuff
assert is_same(actual.fn_factory, expected.fn_factory)
assert is_same(actual.fn, expected.fn)
assert is_same(actual.partial_factory, expected.partial_factory)
assert is_same(actual.partial_, expected.partial_)


def identity_with_dict_default(x={"a": 1}):
return x


@pytest.mark.xfail
def test_known_failcase_hydra_2350():
# https://github.com/facebookresearch/hydra/issues/2350
# Overriding a default-value dictionary via instantiate interface results
# in merging of the default-dictionary with the override value
Conf = builds(identity_with_dict_default, populate_full_signature=True)
actual = instantiate(Conf, x={"b": 2})
expected = {"b": 2}
assert actual == expected, actual

0 comments on commit 7675970

Please sign in to comment.