diff --git a/docs/source/api_reference.rst b/docs/source/api_reference.rst index fddcfe66..752de38f 100644 --- a/docs/source/api_reference.rst +++ b/docs/source/api_reference.rst @@ -12,9 +12,9 @@ Encyclopedia Hydrazennica. All reference documentation includes detailed Examples sections. Please scroll to the bottom of any given reference page to see the examples. -************************* -Launching Jobs with Hydra -************************* +************************************** +Creating and Launching Jobs with Hydra +************************************** hydra-zen provides users the ability to launch a Hydra job via a Python function instead of from a commandline interface. @@ -25,7 +25,8 @@ Python function instead of from a commandline interface. :toctree: generated/ launch - + zen + wrapper.Zen ********************************* Creating and Working with Configs diff --git a/docs/source/changes.rst b/docs/source/changes.rst index eb70d70a..04550603 100644 --- a/docs/source/changes.rst +++ b/docs/source/changes.rst @@ -10,6 +10,66 @@ chronological order. All previous releases should still be available on pip. .. _v0.8.0: +--------------------- +0.9.0rc2 - 2022-10-19 +--------------------- + +.. note:: This is documentation for an unreleased version of hydra-zen. You can try out this version pre-release version using `pip install --pre hydra-zen` + + +Release Highlights +------------------ +This release adds the :func:`~hydra_zen.zen` decorator, which enables users to use Hydra-agnostic task functions for their Hydra app; the decorator will automatically extract, resolve, and instantiate fields from an input config based on the function's signature. + +This encourages users to eliminate Hydra-specific boilerplate code from their projects and to instead opt for task functions with explicit signatures, including functions from third parties. + +E.g., :func:`~hydra_zen.zen` enables us to replace the following Hydra-specific task function: + +.. code-block:: python + :caption: The "old school" way of designing a task function for a Hydra app + + import hydra + from hydra.utils import instantiate + + @hydra.main(config_name="my_app", config_path=None, version_base="1.2") + def trainer_task_fn(cfg): + model = instantiate(cfg.model) + data = instantiate(cfg.data) + partial_optim = instantiate(cfg.partial_optim) + trainer = instantiate(cfg.trainer) + + optim = partial_optim(model.parameters()) + trainer(model, optim, data).fit(cfg.num_epochs) + + if __name__ == "__main__": + trainer_task_fn() + +with a Hydra-agnostic task function that has an explicit signature: + +.. code-block:: python + :caption: Using `zen` to design a Hydra-agnostic task function + + from hydra_zen import zen + + @zen + def trainer_task_fn(model, data, partial_optim, trainer, num_epochs: int): + # All config-field extraction & instantiation is automated/mediated by zen + optim = partial_optim(model.parameters()) + trainer(model, optim, data).fit(num_epochs) + + if __name__ == "__main__": + trainer_task_fn.hydra_main(config_name="my_app", config_path=None) + + +There are plenty more bells and whistles to :func:`~hydra_zen.zen`, refer to :pull:`310` and its reference documentation for more details. + +New Features +------------ +- Adds the :func:`~hydra_zen.zen` decorator (see :pull:`310`) +- Adds the :func:`~hydra_zen.wrapper.Zen` decorator-class (see :pull:`310`) + +.. _v0.8.0: + ------------------ 0.8.0 - 2022-09-13 ------------------ diff --git a/docs/source/generated/hydra_zen.wrapper.Zen.rst b/docs/source/generated/hydra_zen.wrapper.Zen.rst new file mode 100644 index 00000000..0b74bc37 --- /dev/null +++ b/docs/source/generated/hydra_zen.wrapper.Zen.rst @@ -0,0 +1,22 @@ +hydra\_zen.wrapper.Zen +====================== + +.. currentmodule:: hydra_zen.wrapper + +.. autoclass:: Zen + + + .. automethod:: __init__ + .. automethod:: __call__ + .. automethod:: validate + .. automethod:: hydra_main + + + .. rubric:: Methods + + .. autosummary:: + + ~Zen.__init__ + ~Zen.__call__ + ~Zen.validate + ~Zen.hydra_main diff --git a/docs/source/generated/hydra_zen.zen.rst b/docs/source/generated/hydra_zen.zen.rst new file mode 100644 index 00000000..dd6fdc11 --- /dev/null +++ b/docs/source/generated/hydra_zen.zen.rst @@ -0,0 +1,6 @@ +hydra\_zen.zen +=============== + +.. currentmodule:: hydra_zen + +.. autofunction:: zen \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 17c717f5..c83e57db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -111,6 +111,7 @@ python = [testenv] +passenv = * deps = pytest hypothesis commands = pytest \ diff --git a/src/hydra_zen/__init__.py b/src/hydra_zen/__init__.py index f4925e0b..e6851f80 100644 --- a/src/hydra_zen/__init__.py +++ b/src/hydra_zen/__init__.py @@ -21,6 +21,7 @@ ) from .structured_configs._implementations import get_target from .structured_configs._type_guards import is_partial_builds, uses_zen_processing +from .wrapper import zen __all__ = [ "builds", @@ -39,6 +40,7 @@ "launch", "is_partial_builds", "uses_zen_processing", + "zen", ] if not TYPE_CHECKING: diff --git a/src/hydra_zen/_compatibility.py b/src/hydra_zen/_compatibility.py index 4d516982..d7519c1d 100644 --- a/src/hydra_zen/_compatibility.py +++ b/src/hydra_zen/_compatibility.py @@ -36,6 +36,7 @@ def _get_version(ver_str: str) -> Version: OMEGACONF_VERSION: Final = _get_version(omegaconf.__version__) HYDRA_VERSION: Final = _get_version(hydra.__version__) +SUPPORTS_VERSION_BASE = HYDRA_VERSION >= (1, 2, 0) # OmegaConf issue 830 describes a bug associated with structured configs # composed via inheritance, where the child's attribute is a default-factory @@ -70,6 +71,8 @@ def _get_version(ver_str: str) -> Version: range, } +HYDRA_SUPPORTS_LIST_INSTANTIATION = HYDRA_VERSION >= Version(1, 1, 2) + if HYDRA_SUPPORTS_BYTES: # pragma: no cover HYDRA_SUPPORTED_PRIMITIVES.add(bytes) diff --git a/src/hydra_zen/_hydra_overloads.py b/src/hydra_zen/_hydra_overloads.py index 44306910..e41ead27 100644 --- a/src/hydra_zen/_hydra_overloads.py +++ b/src/hydra_zen/_hydra_overloads.py @@ -23,7 +23,7 @@ import pathlib from dataclasses import is_dataclass from functools import wraps -from typing import IO, Any, Callable, Type, TypeVar, Union, cast, overload +from typing import IO, Any, Callable, Dict, List, Type, TypeVar, Union, cast, overload from hydra.utils import instantiate as hydra_instantiate from omegaconf import MISSING, DictConfig, ListConfig, OmegaConf @@ -88,7 +88,15 @@ def instantiate( @overload def instantiate( - config: Union[HasTarget, ListConfig, DictConfig, DataClass_, Type[DataClass_]], + config: Union[ + HasTarget, + ListConfig, + DictConfig, + DataClass_, + Type[DataClass_], + Dict[Any, Any], + List[Any], + ], *args: Any, **kwargs: Any ) -> Any: # pragma: no cover diff --git a/src/hydra_zen/_launch.py b/src/hydra_zen/_launch.py index 5e212e00..888647e7 100644 --- a/src/hydra_zen/_launch.py +++ b/src/hydra_zen/_launch.py @@ -13,7 +13,7 @@ from hydra.types import HydraContext, RunMode from omegaconf import DictConfig, ListConfig, OmegaConf -from hydra_zen._compatibility import HYDRA_VERSION +from hydra_zen._compatibility import SUPPORTS_VERSION_BASE from hydra_zen._hydra_overloads import instantiate from hydra_zen.typing._implementations import DataClass, InstOrType @@ -238,7 +238,7 @@ def launch( job_name=job_name, **( {} - if (HYDRA_VERSION < (1, 2, 0) or version_base is _NotSet) + if (not SUPPORTS_VERSION_BASE or version_base is _NotSet) else {"version_base": version_base} ) ): diff --git a/src/hydra_zen/structured_configs/_type_guards.py b/src/hydra_zen/structured_configs/_type_guards.py index 8111103f..dd485bc8 100644 --- a/src/hydra_zen/structured_configs/_type_guards.py +++ b/src/hydra_zen/structured_configs/_type_guards.py @@ -2,13 +2,14 @@ # SPDX-License-Identifier: MIT # pyright: strict from functools import partial -from typing import Any +from typing import TYPE_CHECKING, Any, Type, Union from typing_extensions import TypeGuard from hydra_zen._compatibility import HYDRA_SUPPORTS_PARTIAL from hydra_zen.funcs import get_obj, zen_processing from hydra_zen.typing import Builds, Just, PartialBuilds +from hydra_zen.typing._implementations import DataClass_ from ._globals import ( JUST_FIELD_NAME, @@ -19,7 +20,7 @@ ZEN_TARGET_FIELD_NAME, ) -__all__ = ["is_partial_builds", "uses_zen_processing"] +__all__ = ["is_partial_builds", "uses_zen_processing", "is_dataclass"] # We need to check if things are Builds, Just, PartialBuilds to a higher # fidelity than is provided by `isinstance(..., )`. I.e. we want to @@ -53,6 +54,15 @@ def is_just(x: Any) -> TypeGuard[Just[Any]]: return False +if TYPE_CHECKING: # pragma: no cover + + def is_dataclass(obj: Any) -> TypeGuard[Union[DataClass_, Type[DataClass_]]]: + ... + +else: + from dataclasses import is_dataclass + + def is_old_partial_builds(x: Any) -> bool: # pragma: no cover # We don't care about coverage here. # This will only be used in `get_target` and we'll be sure to cover that branch diff --git a/src/hydra_zen/wrapper/__init__.py b/src/hydra_zen/wrapper/__init__.py new file mode 100644 index 00000000..55ca40cc --- /dev/null +++ b/src/hydra_zen/wrapper/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2022 Massachusetts Institute of Technology +# SPDX-License-Identifier: MIT + +# pyright: strict + +from ._implementations import Zen, zen + +__all__ = ["zen", "Zen"] diff --git a/src/hydra_zen/wrapper/_implementations.py b/src/hydra_zen/wrapper/_implementations.py new file mode 100644 index 00000000..e4ed9d6e --- /dev/null +++ b/src/hydra_zen/wrapper/_implementations.py @@ -0,0 +1,630 @@ +# Copyright (c) 2022 Massachusetts Institute of Technology +# SPDX-License-Identifier: MIR +# pyright: strict + +from inspect import Parameter, signature +from typing import ( + Any, + Callable, + Dict, + Generic, + Iterable, + List, + Mapping, + Optional, + Sequence, + Set, + Tuple, + Type, + TypeVar, + Union, + overload, +) + +import hydra +from hydra.main import _UNSPECIFIED_ # type: ignore +from omegaconf import DictConfig, ListConfig, OmegaConf +from typing_extensions import Literal, ParamSpec, TypeAlias, TypeGuard + +from hydra_zen import instantiate +from hydra_zen.errors import HydraZenValidationError +from hydra_zen.typing._implementations import DataClass_ + +from .._compatibility import HYDRA_SUPPORTS_LIST_INSTANTIATION, SUPPORTS_VERSION_BASE +from ..structured_configs._type_guards import is_dataclass +from ..structured_configs._utils import safe_name + +__all__ = ["zen"] + + +R = TypeVar("R") +P = ParamSpec("P") + + +if HYDRA_SUPPORTS_LIST_INSTANTIATION: + _SUPPORTED_INSTANTIATION_TYPES: Tuple[Any, ...] = (dict, DictConfig, list, ListConfig) # type: ignore +else: # pragma: no cover + _SUPPORTED_INSTANTIATION_TYPES: Tuple[Any, ...] = (dict, DictConfig) # type: ignore + +ConfigLike: TypeAlias = Union[ + DataClass_, + Type[DataClass_], + Dict[Any, Any], + DictConfig, +] + + +def is_instantiable( + cfg: Any, +) -> TypeGuard[ConfigLike]: + return is_dataclass(cfg) or isinstance(cfg, _SUPPORTED_INSTANTIATION_TYPES) + + +SKIPPED_PARAM_KINDS = frozenset( + (Parameter.POSITIONAL_ONLY, Parameter.VAR_KEYWORD, Parameter.VAR_POSITIONAL) +) + + +PreCall = Optional[Union[Callable[[Any], Any], Iterable[Callable[[Any], Any]]]] + + +def _flat_call(x: Iterable[Callable[P, Any]]) -> Callable[P, None]: + def f(*args: P.args, **kwargs: P.kwargs) -> None: + for fn in x: + fn(*args, **kwargs) + + return f + + +class Zen(Generic[P, R]): + """Implements the decorator logic that is exposed by `hydra_zen.zen` + + Attributes + ---------- + CFG_NAME : str + The reserved parameter name specifies to pass the input config through + to the inner function. Can be overwritted via subclassing. Defaults + to 'zen_cfg' + + See Also + -------- + zen : A decorator that returns a function that will auto-extract, resolve, and instantiate fields from an input config based on the decorated function's signature. + """ + + # Specifies reserved parameter name specified to pass the + # config through to the task function + CFG_NAME: str = "zen_cfg" + + def __repr__(self) -> str: + return f"zen[{(safe_name(self.func))}({', '.join(self.parameters)})](cfg, /)" + + def __init__( + self, + __func: Callable[P, R], + *, + exclude: Optional[Union[str, Iterable[str]]] = None, + pre_call: PreCall = None, + unpack_kwargs: bool = False, + ) -> None: + """ + Parameters + ---------- + func : Callable[Sig, R] + The function being decorated. (This is a positional-only argument) + + unpack_kwargs: bool, optional (default=False) + If `True` a `**kwargs` field in the decorated function's signature will be + populated by all of the input config entries that are not specified by the + rest of the signature (and that are not specified by the `exclude` + argument). + + pre_call : Optional[Callable[[Any], Any] | Iterable[Callable[[Any], Any]]] + One or more functions that will be called with the input config prior + to the decorated functions. An iterable of pre-call functions are called + from left (low-index) to right (high-index). + + exclude: Optional[str | Iterable[str]] + Specifies one or more parameter names in the function's signature + that will not be extracted from input configs by the zen-wrapped function. + + A single string of comma-separated names can be specified. + """ + self.func: Callable[P, R] = __func + + try: + self.parameters: Mapping[str, Parameter] = signature(self.func).parameters + except (ValueError, TypeError): + raise HydraZenValidationError( + "hydra_zen.zen can only wrap callables that possess inspectable signatures." + ) + + if not isinstance(unpack_kwargs, bool): # type: ignore + raise TypeError(f"`unpack_kwargs` must be type `bool` got {unpack_kwargs}") + + self._unpack_kwargs: bool = unpack_kwargs and any( + p.kind is p.VAR_KEYWORD for p in self.parameters.values() + ) + + self._exclude: Set[str] + + if exclude is None: + self._exclude = set() + elif isinstance(exclude, str): + self._exclude = {k.strip() for k in exclude.split(",")} + else: + self._exclude = set(exclude) + + if self.CFG_NAME in self.parameters: + self._has_zen_cfg = True + self.parameters = { + name: param + for name, param in self.parameters.items() + if name != self.CFG_NAME + } + else: + self._has_zen_cfg = False + + self._pre_call_iterable = ( + (pre_call,) if not isinstance(pre_call, Iterable) else pre_call + ) + + # validate pre-call signatures + for _f in self._pre_call_iterable: + if _f is None: + continue + + _f_params = signature(_f).parameters + + if (sum(p.default is p.empty for p in _f_params.values()) > 1) or len( + _f_params + ) == 0: + raise HydraZenValidationError( + f"pre_call function {_f} must be able to accept a single " + "positional argument" + ) + + self.pre_call: Optional[Callable[[Any], Any]] = ( + pre_call if not isinstance(pre_call, Iterable) else _flat_call(pre_call) + ) + + @staticmethod + def _normalize_cfg( + cfg: Union[ + DataClass_, + Type[DataClass_], + Dict[Any, Any], + List[Any], + ListConfig, + DictConfig, + str, + ], + ) -> DictConfig: + if is_dataclass(cfg): + # ensures that default factories and interpolated fields + # are resolved + cfg = OmegaConf.structured(cfg) + + elif not OmegaConf.is_config(cfg): + if not isinstance(cfg, (dict, str)): + raise HydraZenValidationError( + f"`cfg` must be a dataclass, dict/DictConfig, or " + f"dict-style yaml-string. Got {cfg}" + ) + cfg = OmegaConf.create(cfg) + + if not isinstance(cfg, DictConfig): + raise HydraZenValidationError( + f"`cfg` must be a dataclass, dict/DictConfig, or " + f"dict-style yaml-string. Got {cfg}" + ) + return cfg + + def validate(self, __cfg: Union[ConfigLike, str]) -> None: + """Validates the input config based on the decorated function without calling said function. + + Parameters + ---------- + cfg : dict | list | DataClass | Type[DataClass] | str + (positional only) A config object or yaml-string whose attributes will be + checked according to the signature of `func`. + + Raises + ------ + HydraValidationError + `cfg` is not a valid input to the zen-wrapped function. + """ + for _f in self._pre_call_iterable: + if isinstance(_f, Zen): + _f.validate(__cfg) + + cfg = self._normalize_cfg(__cfg) + + num_pos_only = sum( + p.kind is p.POSITIONAL_ONLY for p in self.parameters.values() + ) + + _args_ = getattr(cfg, "_args_", []) + + if not isinstance(_args_, Sequence): + raise HydraZenValidationError( + f"`cfg._args_` must be a sequence type (e.g. a list), got {_args_}" + ) + if num_pos_only and len(_args_) != num_pos_only: + raise HydraZenValidationError( + f"{self.func} has {num_pos_only} positional-only arguments, but " + f"`cfg` specifies {len(getattr(cfg, '_args_', []))} positional " + f"arguments via `_args_`." + ) + + missing_params: List[str] = [] + for name, param in self.parameters.items(): + if name in self._exclude: + continue + + if param.kind in SKIPPED_PARAM_KINDS: + continue + + if not hasattr(cfg, name) and param.default is param.empty: + missing_params.append(name) + + if missing_params: + raise HydraZenValidationError( + f"`cfg` is missing the following fields: {', '.join(missing_params)}" + ) + + def __call__(self, __cfg: Union[ConfigLike, str]) -> R: + """ + Extracts values from the input config based on the decorated function's + signature, resolves & instantiates them, and calls the function with them. + + Parameters + ---------- + cfg : dict | DataClass | Type[DataClass] | str + (positional only) A config object or yaml-string whose attributes will be + extracted by-name according to the signature of `func` and passed to `func`. + + Attributes of types that can be instantiated by Hydra will be instantiated + prior to being passed to `func`. + + Returns + ------- + func_out : R + The result of `func()` + """ + cfg = self._normalize_cfg(__cfg) + # resolves all interpolated values in-place + OmegaConf.resolve(cfg) + + if self.pre_call is not None: + self.pre_call(cfg) + + args_ = list(getattr(cfg, "_args_", [])) + + cfg_kwargs = { + name: ( + getattr(cfg, name, param.default) + if param.default is not param.empty + else getattr(cfg, name) + ) + for name, param in self.parameters.items() + if param.kind not in SKIPPED_PARAM_KINDS and name not in self._exclude + } + + extra_kwargs = {self.CFG_NAME: cfg} if self._has_zen_cfg else {} + if self._unpack_kwargs: + names = ( + name + for name in cfg + if name not in cfg_kwargs + and name not in self._exclude + and isinstance(name, str) + ) + cfg_kwargs.update({name: cfg[name] for name in names}) + return self.func( + *(instantiate(x) if is_instantiable(x) else x for x in args_), + **{ + name: instantiate(val) if is_instantiable(val) else val + for name, val in cfg_kwargs.items() + }, + **extra_kwargs, + ) # type: ignore + + def hydra_main( + self, + config_path: Optional[str] = _UNSPECIFIED_, + config_name: Optional[str] = None, + version_base: Optional[str] = _UNSPECIFIED_, + ) -> Callable[[Any], Any]: + """ + Returns a Hydra-CLI compatible version of the wrapped function: `hydra.main(zen(func))` + + Parameters + ---------- + config_path : Optional[str] + The config path, a directory relative to the declaring python file. + + If config_path is not specified no directory is added to the Config search path. + + config_name : Optional[str] + The name of the config (usually the file name without the .yaml extension) + + version_base : Optional[str] + There are three classes of values that the version_base parameter supports, given new and existing users greater control of the default behaviors to use. + + - If the version_base parameter is not specified, Hydra 1.x will use defaults compatible with version 1.1. Also in this case, a warning is issued to indicate an explicit version_base is preferred. + - If the version_base parameter is None, then the defaults are chosen for the current minor Hydra version. For example for Hydra 1.2, then would imply config_path=None and hydra.job.chdir=False. + - If the version_base parameter is an explicit version string like "1.1", then the defaults appropriate to that version are used. + + Returns + ------- + hydra_main : Callable[[Any], Any] + Equivalent to `hydra.main(zen(func), [...])()` + """ + + kw = dict(config_path=config_path, config_name=config_name) + if SUPPORTS_VERSION_BASE: + kw["version_base"] = version_base + + return hydra.main(**kw)(self)() + + +@overload +def zen( + __func: Callable[P, R], + *, + unpack_kwargs: bool = ..., + pre_call: PreCall = ..., + ZenWrapper: Type[Zen[P, R]] = Zen, + exclude: Optional[Union[str, Iterable[str]]] = None, +) -> Zen[P, R]: # pragma: no cover + ... + + +@overload +def zen( + __func: Literal[None] = ..., + *, + unpack_kwargs: bool = ..., + pre_call: PreCall = ..., + ZenWrapper: Type[Zen[Any, Any]] = ..., + exclude: Optional[Union[str, Iterable[str]]] = None, +) -> Callable[[Callable[P, R]], Zen[P, R]]: # pragma: no cover + ... + + +def zen( + __func: Optional[Callable[P, R]] = None, + *, + unpack_kwargs: bool = False, + pre_call: PreCall = None, + exclude: Optional[Union[str, Iterable[str]]] = None, + ZenWrapper: Type[Zen[P, R]] = Zen, +) -> Union[Zen[P, R], Callable[[Callable[P, R]], Zen[P, R]]]: + r"""zen(func, /, pre_call, ZenWrapper) + + A decorator that returns a function that will auto-extract, resolve, and + instantiate fields from an input config based on the decorated function's signature. + + .. code-block:: pycon + + >>> Cfg = dict(x=1, y=builds(int, 4), z="${y}", unused=100) + >>> zen(lambda x, y, z : x+y+z)(Cfg) # x=1, y=4, z=4 + 9 + + The main purpose of `zen` is to enable a user to write/use Hydra-agnostic functions + as the task functions for their Hydra app. + + Parameters + ---------- + func : Callable[Sig, R] + The function being decorated. (This is a positional-only argument) + + unpack_kwargs: bool, optional (default=False) + If `True` a `**kwargs` field in the decorated function's signature will be + populated by all of the input config entries that are not specified by the rest + of the signature (and that are not specified by the `exclude` argument). + + pre_call : Optional[Callable[[Any], Any] | Iterable[Callable[[Any], Any]]] + One or more functions that will be called with the input config prior + to the decorated functions. An iterable of pre-call functions are called + from left (low-index) to right (high-index). + + This is useful, e.g., for seeding a RNG prior to the instantiation phase + that is triggered when calling the decorated function. + + exclude: Optional[str | Iterable[str]] + Specifies one or more parameter names in the function's signature + that will not be extracted from input configs by the zen-wrapped function. + + A single string of comma-separated names can be specified. + + ZenWrapper : Type[hydra_zen.wrapper.Zen], optional (default=Zen) + If specified, a subclass of `Zen` that customizes the behavior of the wrapper. + + Returns + ------- + Zen[Sig, R] + A callable with signature `(conf: ConfigLike, \\) -> R` + + The wrapped function is an instance of `hydra_zen.wrapper.Zen` and accepts + a single Hydra config (a dataclass, dictionary, or omegaconf container). The + parameters of the decorated function's signature determine the fields that are + extracted from the config; only those fields that are accessed will be resolved + and instantiated. + + Notes + ----- + ConfigLike is DataClass | list[Any] | dict[str, Any] | DictConfig | ListConfig + + The fields extracted from the input config are determined by the signature of the + decorated function. There is an exception: including a parameter named "zen_cfg" + in the function's signature will signal to `zen` to pass through the full config to + that field (This specific parameter name can be overridden via `Zen.CFG_NAME`). + + All values (extracted from the input config) of types belonging to ConfigLike will + be instantiated before being passed to the wrapped function. + + Examples + -------- + **Basic Usage** + + Using `zen` as a decorator + + >>> from hydra_zen import zen, make_config, builds + >>> @zen + ... def f(x, y): return x + y + + The resulting decorated function accepts a single argument: a Hydra-compatible config that has the attributes "x" and "y": + + >>> f + zen[f(x, y)](cfg, /) + + "Configs" – dataclasses, dictionaries, and omegaconf containers – are acceptable + inputs to zen-wrapped functions. Interpolated fields will be resolved and + sub-configs will be instantiated. Excess fields in the config are unused. + + >>> f(make_config(x=1, y=2, z=999)) # z is not used + 3 + >>> f(dict(x=2, y="${x}")) # y will resolve to 2 + 4 + >>> f(dict(x=2, y=builds(int, 10))) # y will instantiate to 10 + 12 + + The wrapped function can be accessed directly + + >>> f.func + + >>> f.func(-1, 1) + 0 + + `zen` is compatible with partial'd functions. + + >>> from functools import partial + >>> pf = partial(lambda x, y: x + y, x=10) + >>> zpf = zen(pf) + >>> zpf(dict(y=1)) + 11 + >>> zpf(dict(x='${y}', y=1)) + 2 + + One can specify `exclude` to prevent particular variables from being extracted from + a config: + + >>> def f(x=1, y=2): return (x, y) + >>> zen(f)(dict(x=-10, y=-20)) # extracts x & y from config to call f + (-10, -20) + >>> zen(f, exclude="x")(dict(x=-10, y=-20)) # extracts y from config to call f(x=1, ...) + (1, -20) + >>> zen(f, exclude="x,y")(dict(x=-10, y=-20)) # defers to f's defaults + (1, 2) + + Populating a `**kwargs` field via `unpack_kwargs=True`: + + >>> def f(a, **kw): + ... return a, kw + + >>> cfg = dict(a=1, b=22) + >>> zen(f, unpack_kwargs=False)(cfg) + (1, {}) + >>> zen(f, unpack_kwargs=True)(cfg) + (1, {'b': 22}) + + **Including a pre-call function** + + Given that a zen-wrapped function will automatically extract and instantiate config + fields upon being called, it can be necessary to include a pre-call step that + occurs prior to any instantiation. `zen` can be passed one or more pre-call + functions that will be called with the input config as a precursor to calling the + decorated function. + + Consider the following scenario where the config's instantiation involves drawing a + random value, which we want to be made deterministic with a configurable seed. We + will use a pre-call function to seed the RNG prior to the subsequent instantiation. + + >>> import random + >>> from hydra_zen import builds, zen + + >>> Cfg = dict( + ... # `rand_val` will be instantiated and draw from randint upon + ... # calling the zen-wrapped function, thus we need a pre-call + ... # function to set the global RNG seed prior to instantiation + ... rand_val=builds(random.randint, 0, 10), + ... seed=0, + ... ) + + >>> @zen(pre_call=lambda cfg: random.seed(cfg.seed)) + ... def f(rand_val: int): + ... return rand_val + + >>> [f(Cfg) for _ in range(10)] + [6, 6, 6, 6, 6, 6, 6, 6, 6, 6] + + + **Using `@zen` instead of `@hydra.main`** + + The object returned by zen provides a convenience method – `Zen.hydra_main` – so + that users need not double-wrap with `@hydra.main` to create a CLI: + + .. code-block:: python + + # example.py + from hydra.core.config_store import ConfigStore + + from hydra_zen import builds, zen + + def task(x: int, y: int): + print(x + y) + + cs = ConfigStore.instance() + cs.store(name="my_app", node=builds(task, populate_full_signature=True)) + + + if __name__ == "__main__": + zen(task).hydra_main(config_name="my_app", config_path=None) + + .. code-block:: console + + $ python example.py x=1 y=2 + 3 + + + **Validating input configs** + + An input config can be validated against a zen-wrapped function without calling said function via the `.validate` method. + + >>> def f(x: int): ... + >>> zen_f = zen(f) + >>> zen_f.validate({"x": 1}) # OK + >>> zen_f.validate({"y": 1}) # Missing x + HydraZenValidationError: `cfg` is missing the following fields: x + + Validation propagates through zen-wrapped pre-call functions: + + >>> zen_f = zen(f, pre_call=zen(lambda seed: None)) + >>> zen_f.validate({"x": 1, "seed": 10}) # OK + >>> zen_f.validate({"x": 1}) # Missing seed as required by pre-call + HydraZenValidationError: `cfg` is missing the following fields: seed + + **Passing Through The Config** + + Some task functions require complete access to the full config to gain access to + sub-configs. One can specify the field named `zen_config` in their task function's + signature to signal `zen` that it should pass the full config to that parameter . + + >>> @zen + ... def f(x: int, zen_cfg): + ... return x, zen_cfg + >>> f(dict(x=1, y="${x}")) + (1, {'x': 1, 'y': 1}) + """ + if __func is not None: + return ZenWrapper( + __func, pre_call=pre_call, exclude=exclude, unpack_kwargs=unpack_kwargs + ) + + def wrap(f: Callable[P, R]) -> Zen[P, R]: + return ZenWrapper( + f, pre_call=pre_call, exclude=exclude, unpack_kwargs=unpack_kwargs + ) + + return wrap diff --git a/tests/annotations/declarations.py b/tests/annotations/declarations.py index 0bd315fa..56c17792 100644 --- a/tests/annotations/declarations.py +++ b/tests/annotations/declarations.py @@ -21,7 +21,7 @@ from typing import Any, Callable, List, Mapping, Optional, Tuple, Type, TypeVar, Union from omegaconf import MISSING, DictConfig, ListConfig -from typing_extensions import Literal +from typing_extensions import Literal, assert_type from hydra_zen import ( ZenField, @@ -32,6 +32,7 @@ make_config, make_custom_builds_fn, mutable_value, + zen, ) from hydra_zen.structured_configs._value_conversion import ConfigComplex, ConfigPath from hydra_zen.typing import ( @@ -45,6 +46,7 @@ ) from hydra_zen.typing._builds_overloads import FullBuilds, PBuilds, StdBuilds from hydra_zen.typing._implementations import DataClass_, HydraPartialBuilds +from hydra_zen.wrapper import Zen T = TypeVar("T") @@ -1031,3 +1033,103 @@ def f(x: InstOrType[DataClass]): f(Xonf) launch(Xonf, f) + + +def check_instantiate(): + @dataclass + class Cfg: + ... + + assert_type(instantiate(DictConfig({})), Any) + assert_type(instantiate({}), Any) + assert_type(instantiate(ListConfig([])), Any) + assert_type(instantiate([]), Any) + assert_type(instantiate(Cfg), Any) + assert_type(instantiate(Cfg()), Any) + + +def check_zen(): + @zen + def zen_f(x: int) -> str: + ... + + assert_type(zen_f({"a": 1}), str) + assert_type(zen_f(DictConfig({"a": 1})), str) + assert_type(zen_f("some yaml"), str) + + assert_type(zen_f([])) # type: ignore + assert_type(zen_f(ListConfig([]))) # type: ignore + + zen_f(1) # type: ignore + reveal_type(zen_f.func, expected_text="(x: int) -> str") + + @zen(pre_call=None) + def zen_f2(x: int) -> str: + ... + + assert_type(zen_f2({"a": 1}), str) + assert_type(zen_f2(DictConfig({"a": 1})), str) + assert_type(zen_f2("some yaml"), str) + + zen_f2(1) # type: ignore + reveal_type(zen_f2.func, expected_text="(x: int) -> str") + + class MyZen(Zen): + ... + + @zen(ZenWrapper=MyZen) + def zen_rewrapped(x: int) -> str: + ... + + reveal_type(zen_rewrapped, expected_text="Zen[(x: int), str]") + + @zen(unpack_kwargs=True) + def unpacks_kw(**kw): + ... + + def f(x: int): + ... + + zen_rewrapped2 = zen(f, ZenWrapper=MyZen) + + reveal_type(zen_rewrapped2, expected_text="Zen[(x: int), None]") + + # valid pre-call + @zen(pre_call=lambda cfg: None) + def h1(): + ... + + @zen(pre_call=[lambda cfg: None]) + def h2(): + ... + + @zen(pre_call=zen(lambda x, y: None)) + def h3(): + ... + + # bad pre-call + + @zen(pre_call=1) # type: ignore + def g1(): + ... + + @zen(pre_call=lambda x, y: None) # type: ignore + def g2(): + ... + + @zen(pre_call=[lambda x, y: None]) # type: ignore + def g3(): + ... + + # valid excludes + @zen(exclude="a") + def p1(): + ... + + @zen(exclude=("a" for _ in range(1))) + def p2(): + ... + + @zen(exclude=1) # type: ignore + def p3(): + ... diff --git a/tests/conftest.py b/tests/conftest.py index 6578126f..ee0a7d95 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,14 +24,14 @@ "pytorch_lightning", "pydantic", "beartype", + "submitit", ) _installed = {pkg.key for pkg in pkg_resources.working_set} for _module_name in OPTIONAL_TEST_DEPENDENCIES: if _module_name not in _installed: - collect_ignore_glob.append(f"**/*{_module_name}*.py") - + collect_ignore_glob.append(f"*{_module_name}*.py") if sys.version_info > (3, 6): collect_ignore_glob.append("*py36*") diff --git a/tests/custom_strategies.py b/tests/custom_strategies.py index 51eca86a..82827df1 100644 --- a/tests/custom_strategies.py +++ b/tests/custom_strategies.py @@ -7,7 +7,7 @@ from hydra_zen.structured_configs._utils import get_obj_path from hydra_zen.typing._implementations import ZenConvert -__all__ = ["valid_builds_args", "partitions"] +__all__ = ["valid_builds_args", "partitions", "everything_except"] _Sequence = Union[List, Tuple, Deque] T = TypeVar("T", bound=Union[_Sequence, Dict[str, Any]]) @@ -100,3 +100,11 @@ def partitions( if isinstance(collection, st.SearchStrategy): return collection.flatmap(lambda x: _partition(x, ordered=ordered)) # type: ignore return cast(st.SearchStrategy[Tuple[T, T]], _partition(collection, ordered)) + + +def everything_except(excluded_types): + return ( + st.from_type(type) + .flatmap(st.from_type) + .filter(lambda x: not isinstance(x, excluded_types)) + ) diff --git a/tests/dummy_zen_main.py b/tests/dummy_zen_main.py new file mode 100644 index 00000000..3ab24628 --- /dev/null +++ b/tests/dummy_zen_main.py @@ -0,0 +1,21 @@ +import random + +from hydra.core.config_store import ConfigStore + +from hydra_zen import make_config, zen + +cs = ConfigStore.instance() +cs.store(name="my_app", node=make_config("x", "y", z="${y}", seed=12)) + + +@zen(pre_call=lambda cfg: random.seed(cfg.seed)) +def f(x: int, y: int, z: int): + pre_seeded = random.randint(0, 10) + random.seed(12) + seeded = random.randint(0, 10) + assert pre_seeded == seeded + return x + y + z + + +if __name__ == "__main__": + f.hydra_main(config_name="my_app", config_path=None) diff --git a/tests/test_validation_py38.py b/tests/test_validation_py38.py index c698ddb2..26359f04 100644 --- a/tests/test_validation_py38.py +++ b/tests/test_validation_py38.py @@ -8,7 +8,8 @@ from hypothesis import given from omegaconf import OmegaConf -from hydra_zen import builds, instantiate, to_yaml +from hydra_zen import builds, instantiate, make_config, to_yaml, zen +from hydra_zen.errors import HydraZenValidationError from tests import valid_hydra_literals @@ -20,6 +21,19 @@ def xy_are_pos_only(x, y, /): return x, y +def test_zen_decorator_with_positional_only_args(): + zen_f = zen(x_is_pos_only) + with pytest.raises( + HydraZenValidationError, + match=r"has 1 positional-only arguments, but `cfg` specifies 0", + ): + zen_f.validate(make_config()) + + cfg = builds(x_is_pos_only, 1) + zen_f.validate(cfg) + assert zen_f(cfg) == 1 + + @pytest.mark.parametrize("func", [x_is_pos_only]) @given(partial=st.none() | st.booleans(), full_sig=st.booleans()) def test_builds_runtime_validation_pos_only_not_nameable(func, full_sig, partial): diff --git a/tests/test_with_hydra_submitit.py b/tests/test_with_hydra_submitit.py new file mode 100644 index 00000000..1bdccf6b --- /dev/null +++ b/tests/test_with_hydra_submitit.py @@ -0,0 +1,40 @@ +import os +import pickle +import sys + +import pytest + + +@pytest.mark.skipif( + sys.platform.startswith("win") and bool(os.environ.get("CI")), + reason="Things are weird on GitHub Actions and Windows", +) +@pytest.mark.usefixtures("cleandir") +def test_pickling_with_hydra_main(): + """This test uses hydra-submitit-launcher because + submitit uses cloudpickle to pickle the task function + and execute a job from the pickled task function.""" + import subprocess + from pathlib import Path + + path = (Path(__file__).parent / "dummy_zen_main.py").absolute() + assert not (Path.cwd() / "multirun").is_dir() + subprocess.run( + ["python", path, "x=1", "y=2", "hydra/launcher=submitit_local", "--multirun"] + ).check_returncode() + assert (Path.cwd() / "multirun").is_dir() + + multirun_files = list(Path.cwd().glob("**/multirun.yaml")) + assert len(multirun_files) == 1 + multirun_file = multirun_files[0] + assert (multirun_file.parent / ".submitit").is_dir() + + # load the results saved by submitit + pkls = list((multirun_file.parent / ".submitit").glob("**/*_result.pkl")) + assert len(pkls) == 1 + with open(pkls[0], "rb") as f: + result = pickle.load(f) + + # assert the result is correct and the task function executed + assert result[0] == "success" + assert result[1].return_value == 5 diff --git a/tests/test_zen_decorator.py b/tests/test_zen_decorator.py new file mode 100644 index 00000000..08bc5495 --- /dev/null +++ b/tests/test_zen_decorator.py @@ -0,0 +1,484 @@ +# Copyright (c) 2022 Massachusetts Institute of Technology +# SPDX-License-Identifier: MIT + +import os +import random +import sys +from dataclasses import dataclass +from typing import Any, Tuple + +import pytest +from hypothesis import example, given, strategies as st +from omegaconf import DictConfig + +from hydra_zen import builds, make_config, to_yaml, zen +from hydra_zen.errors import HydraZenValidationError +from hydra_zen.wrapper import Zen +from tests.custom_strategies import everything_except + + +@zen +def zen_identity(x: int): + return x + + +def function(x: int, y: int, z: int = 2): + return x * y * z + + +def function_with_args(x: int, y: int, z: int = 2, *args): + return x * y * z + + +def function_with_kwargs(x: int, y: int, z: int = 2, **kwargs): + return x * y * z + + +def function_with_args_kwargs(x: int, y: int, z: int = 2, *args, **kwargs): + return x * y * z + + +def test_zen_basic_usecase(): + @zen + def f(x: int, y: str): + return x * y + + Cfg = make_config(x=builds(int, 2), y="cow", unused="unused") + assert f(Cfg) == 2 * "cow" + + +def test_zen_repr(): + assert repr(zen(lambda x, y: None)) == "zen[(x, y)](cfg, /)" + assert ( + repr(zen(pre_call=lambda x: x)(lambda x, y: None)) + == "zen[(x, y)](cfg, /)" + ) + assert repr(zen(make_config("x", "y"))) == "zen[Config(x, y)](cfg, /)" + + +@pytest.mark.parametrize( + "exclude,expected", + [ + (None, (1, 0)), + ("y", (1, 2)), + ("y,", (1, 2)), + (["y"], (1, 2)), + ("x,y", (-2, 2)), + (["x", "y"], (-2, 2)), + ], +) +@given(unpack_kw=st.booleans()) +def test_zen_excluded_param(exclude, expected, unpack_kw): + zenf = zen(lambda x=-2, y=2, **kw: (x, y), exclude=exclude, unpack_kwargs=unpack_kw) + conf = dict(x=1, y=0) + assert zenf(conf) == expected + + +@pytest.mark.parametrize( + "target", + [ + zen(function), + zen(function_with_args), + zen(function_with_kwargs), + zen(function_with_args_kwargs), + zen_identity, + zen(lambda x: x), + ], +) +def test_repr_doesnt_crash(target): + assert isinstance(repr(target), str) + + +@pytest.mark.parametrize("precall", [None, lambda x: x]) +def test_zen_wrapper_trick(precall): + def f(x): + return x + + # E.g. @zen + # def f(...) + z1 = zen(f, pre_call=precall) # direct wrap + assert isinstance(z1, Zen) and z1.func is f and z1.pre_call is precall + + # E.g. @zen(pre_call=...) + # def f(...) + z2 = zen(pre_call=precall)(f) # config then wrap + assert isinstance(z2, Zen) and z2.func is f and z1.pre_call is precall + + +@given(seed=st.sampled_from(range(4))) +def test_zen_pre_call_precedes_instantiation(seed: int): + @zen(pre_call=zen(lambda seed: random.seed(seed))) + def f(x: int, y: int): + return x, y + + actual = f( + make_config( + x=builds(int, 2), + y=builds(random.randint, 0, 10), + seed=seed, + ) + ) + + random.seed(seed) + expected = f.func(2, random.randint(0, 10)) + + assert actual == expected + + +@given(y=st.sampled_from(range(10)), as_dict_config=st.booleans()) +def test_interpolations_are_resolved(y: int, as_dict_config: bool): + @zen(unpack_kwargs=True) + def f(dict_, list_, builds_, make_config_, direct, **kw): + return dict_, list_, builds_, make_config_, direct, kw["nested"] + + cfg_maker = make_config if not as_dict_config else lambda **kw: DictConfig(kw) + out = f( + cfg_maker( + dict_={"x": "${y}"}, + list_=["${y}"], + builds_=builds(dict, a="${y}"), + make_config_=make_config(b="${y}"), + direct="${y}", + nested=dict(top=dict(bottom="${...y}")), + y=y, + ) + ) + assert out == ({"x": y}, [y], {"a": y}, {"b": y}, y, {"top": {"bottom": y}}) + + +@pytest.mark.parametrize( + "cfg", + [ + dict(x=21), + DictConfig(dict(x=21)), + make_config(x=21), + to_yaml(dict(x=21)), + ], +) +def test_supported_config_types(cfg): + @zen + def f(x): + return x + + assert f(cfg) == 21 + + +def test_zen_resolves_default_factories(): + Cfg = make_config(x=[1, 2, 3]) + assert zen_identity(Cfg) == [1, 2, 3] + assert zen_identity(Cfg()) == [1, 2, 3] + + +def test_zen_works_on_partiald_funcs(): + from functools import partial + + def f(x: int, y: str): + return x, y + + zen_pf = zen(partial(f, x=1)) + + with pytest.raises( + HydraZenValidationError, + match=r"`cfg` is missing the following fields: y", + ): + zen_pf.validate(make_config()) + + zen_pf.validate(make_config(y="a")) + assert zen_pf(make_config(y="a")) == (1, "a") + zen_pf.validate(make_config(x=2, y="a")) + assert zen_pf(make_config(x=2, y="a")) == (2, "a") + + +@given(x=st.integers(-5, 5), y=st.integers(-5, 5), unpack_kw=st.booleans()) +def test_zen_cfg_passthrough(x: int, y: int, unpack_kw: bool): + @zen(unpack_kwargs=unpack_kw) + def f(x: int, zen_cfg, **kw): + return (x, zen_cfg) + + x_out, cfg = f({"x": x, "y": y, "z": "${y}"}) + assert x_out == x + assert cfg == {"x": x, "y": y, "z": y} + + +@given(x=st.integers(-5, 5), wrap_mode=st.sampled_from(["decorator", "inline"])) +def test_custom_zen_wrapper(x, wrap_mode): + class MyZen(Zen): + CFG_NAME: str = "secret_cfg" + + def __call__(self, cfg) -> Tuple[Any, str]: + return (super().__call__(cfg), "moo") + + if wrap_mode == "decorator": + + @zen(ZenWrapper=MyZen) + def f(x: int, secret_cfg): + return x, secret_cfg + + elif wrap_mode == "inline": + + def _f(x: int, secret_cfg): + return x, secret_cfg + + f = zen(_f, ZenWrapper=MyZen) + else: + assert False + + cfg = {"x": x} + f.validate(cfg) + + (out_x1, cfg), moo = f(cfg) # type: ignore + assert out_x1 == x + assert cfg == {"x": x} + assert moo == "moo" + + +class A: + def f(self, x: int, y: int, z: int = 2): + return x * y * z + + +method = A().f + + +@pytest.mark.parametrize( + "func", + [ + function, + function_with_args, + function_with_kwargs, + function_with_args_kwargs, + method, + ], +) +@pytest.mark.parametrize( + "cfg", + [ + make_config(), # missing x & y + make_config(not_a_field=2), + make_config(x=1), # missing y + make_config(y=2, z=4), # missing x + ], +) +def test_zen_validation_cfg_missing_parameter(cfg, func): + with pytest.raises( + HydraZenValidationError, + match=r"`cfg` is missing the following fields", + ): + zen(func).validate(cfg) + + +@example(to_yaml(["a", "b"])) +@example(["a", "b"]) +@given(bad_config=everything_except((dict, str))) +def test_zen_validate_bad_config(bad_config): + @zen + def f(*a, **k): + ... + + with pytest.raises( + HydraZenValidationError, + match=r"`cfg` must be a ", + ): + f(bad_config) + + +def test_validate_unpack_kwargs(): + with pytest.raises(TypeError, match=r"`unpack_kwargs` must be type `bool`"): + zen(lambda a: None, unpack_kwargs="apple") # type: ignore + + +def test_zen_validation_excluded_param(): + zen(lambda x: ..., exclude="x").validate(make_config()) + + +def test_zen_validation_cfg_has_bad_pos_args(): + def f(x): + return x + + @dataclass + class BadCfg: + _args_: int = 1 # bad _args_ + + with pytest.raises( + HydraZenValidationError, + match=r"`cfg._args_` must be a sequence type", + ): + zen(f).validate(BadCfg) + + +@pytest.mark.parametrize("not_inspectable", [range(1), False]) +def test_zen_validate_no_sig(not_inspectable): + with pytest.raises( + HydraZenValidationError, + match="hydra_zen.zen can only wrap callables that possess inspectable signatures", + ): + zen(not_inspectable) + + +@pytest.mark.parametrize( + "bad_pre_call", [lambda: None, lambda x, y: None, lambda x, y, z=1: None] +) +def test_pre_call_validates_wrong_num_args(bad_pre_call): + with pytest.raises( + HydraZenValidationError, + match=r"must be able to accept a single positional argument", + ): + zen( + lambda x: None, + pre_call=bad_pre_call, + ) + + +def test_pre_call_validates_bad_param_name(): + with pytest.raises( + HydraZenValidationError, + match=r"`cfg` is missing the following fields", + ): + zen( + lambda x: None, + pre_call=zen(lambda missing: None), + ).validate({"x": 1}) + + +@pytest.mark.parametrize( + "func", + [ + function, + function_with_args, + function_with_kwargs, + function_with_args_kwargs, + method, + ], +) +@given( + x=st.integers(-10, 10), + y=st.integers(-10, 10), + # kwargs=st.dictionaries(st.sampled_from(["z", "not_a_field"]), st.integers()), + instantiate_cfg=st.booleans(), + unpack_kw=st.booleans(), +) +def test_zen_call(x: int, y: int, instantiate_cfg, func, unpack_kw): + + cfg = make_config(x=x, y=y) + if instantiate_cfg: + cfg = cfg() + + # kwargs.pop("not_a_field", None) + expected = func(x, y) + actual = zen(func, unpack_kwargs=unpack_kw)(cfg) + assert expected == actual + + +@given(x=st.sampled_from(range(10)), unpack_kw=st.booleans()) +def test_zen_function_respects_with_defaults(x, unpack_kw: bool): + @zen(unpack_kwargs=unpack_kw) + def f(x: int = 2, **kw): + return x + + assert f(make_config()) == 2 # defer to x's default + assert f(make_config(x=x)) == x # overwrite x + + +def raises(): + raise AssertionError("shouldn't have been called!") + + +@pytest.mark.parametrize( + "call", + [ + lambda x: zen_identity(make_config(x=builds(int, x))), + lambda x: zen_identity(make_config(x=builds(int, x), y=builds(raises))), + lambda x: zen(lambda x: x, unpack_kwargs=True)( + dict(x=builds(int, x), y=builds(raises)) + ), + lambda x: zen(lambda **kw: kw["x"], unpack_kwargs=True, exclude="y")( + dict(x=builds(int, x), y=builds(raises)) + ), + ], +) +@given(x=st.sampled_from(range(10))) +def test_instantiation_only_occurs_as_needed(call, x): + assert call(x) == x + + +def test_zen_works_with_non_builds(): + bigger_cfg = make_config(super_conf=make_config(a=builds(int))) + out = zen(lambda super_conf: super_conf)(bigger_cfg) + assert out.a == 0 + + +class Pre: + record = [] + + +pre_call_strat = st.just(lambda cfg: Pre.record.append(cfg.x)) | st.just( + lambda cfg, optional=None: Pre.record.append(cfg.x) +) + + +@given( + pre_call=(pre_call_strat | st.lists(pre_call_strat)), +) +def test_multiple_pre_calls(pre_call): + Pre.record.clear() + cfg = make_config(x=1, y="a") + g = zen_identity.func + assert zen(pre_call=pre_call)(g)(cfg) == 1 + assert Pre.record == [1] * (len(pre_call) if isinstance(pre_call, list) else 1) + + +@pytest.mark.skipif( + sys.platform.startswith("win") and bool(os.environ.get("CI")), + reason="Things are weird on GitHub Actions and Windows", +) +@pytest.mark.usefixtures("cleandir") +def test_hydra_main(): + import subprocess + from pathlib import Path + + from hydra_zen import load_from_yaml + + path = (Path(__file__).parent / "dummy_zen_main.py").absolute() + assert not (Path.cwd() / "outputs").is_dir() + subprocess.run(["python", path, "x=1", "y=2"]).check_returncode() + assert (Path.cwd() / "outputs").is_dir() + + *_, latest_job = sorted((Path.cwd() / "outputs").glob("*/*")) + + assert load_from_yaml(latest_job / ".hydra" / "config.yaml") == { + "x": 1, + "y": 2, + "z": "${y}", + "seed": 12, + } + + +@pytest.mark.parametrize( + "zen_func", + [ + zen(lambda x, **kw: {"x": x, **kw}, unpack_kwargs=True), + zen(unpack_kwargs=True)(lambda x, **kw: {"x": x, **kw}), + ], +) +@given( + x=st.sampled_from(range(-3, 3)), + # non-string keys should be skipped by unpack_kw + kw=st.dictionaries( + st.sampled_from("abcdef") | st.sampled_from([1, 2]), st.integers() + ), +) +def test_unpack_kw_basic_behavior(zen_func, x, kw): + inp = dict(x=builds(int, x)) + inp.update(kw) + out = zen_func(inp) + expected = {"x": x, **{k: v for k, v in kw.items() if isinstance(k, str)}} + assert out == expected + + +def test_unpack_kw_non_redundant(): + x, y, kw = zen(lambda x, y=2, **kw: (x, y, kw), unpack_kwargs=True)( + dict(x=1, z="${x}") + ) + assert x == 1 + assert y == 2 + assert kw == {"z": 1} # x should not be in kw