Adds @zen decorator for task functions #310
I always like the idea of decoupling the application logic from the configuration framework (à la Bob Martin's dictum that frameworks should be kept at arm's length). This [...] With complex applications, it might be difficult to migrate from [...] One possible pattern to make gradual migration easier would be for [...]
from hydra_zen import builds, instantiate, make_config, zen

class Foo:
    def __init__(self, x: str) -> None:
        pass

Cfg = make_config(seed=1, foo=builds(Foo, x="bar"), unused=[1, 2])

# gradually migrate `old_task_fn` to `new_task_fn`:
def old_task_fn(cfg):
    seed = cfg.seed
    foo = instantiate(cfg.foo)
    print(seed, foo)

@zen
def new_task_fn_v0(zen_cfg):  # special `zen_cfg` keyword passes unmodified config
    cfg = zen_cfg
    seed = cfg.seed
    foo = instantiate(cfg.foo)
    print(seed, foo)

@zen
def new_task_fn_v1(seed, zen_cfg):
    cfg = zen_cfg
    foo = instantiate(cfg.foo)
    print(seed, foo)

@zen
def new_task_fn_v2(seed, foo, zen_cfg):
    cfg = zen_cfg
    print(seed, foo)

@zen
def new_task_fn_final(seed, foo):
    print(seed, foo)

# all of the below are equivalent:
old_task_fn(Cfg)
new_task_fn_v0(Cfg)
new_task_fn_v1(Cfg)
new_task_fn_v2(Cfg)
new_task_fn_final(Cfg)
$ # Below is the diff I used to accomplish this special treatment of the `zen_cfg` keyword argument
$ git diff
diff --git a/src/hydra_zen/_zen.py b/src/hydra_zen/_zen.py
index 5ba7e99..ecff97f 100644
--- a/src/hydra_zen/_zen.py
+++ b/src/hydra_zen/_zen.py
@@ -169,15 +169,20 @@ class Zen(Generic[P, T1]):
                 else getattr(cfg, name)
             )
             for name, param in self.parameters.items()
+            if name != "zen_cfg"
             if param.kind not in SKIPPED_PARAM_KINDS
         }
+        kwargs_final = {
+            name: instantiate(val) if is_instantiable(val) else val
+            for name, val in cfg_kwargs.items()
+        }
+        if "zen_cfg" in self.parameters:
+            kwargs_final["zen_cfg"] = cfg
+
         out = self.func(
             *(instantiate(x) if is_instantiable(x) else x for x in args_),
-            **{
-                name: instantiate(val) if is_instantiable(val) else val
-                for name, val in cfg_kwargs.items()
-            },
+            **kwargs_final,
         )  # type: ignore
         return out
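The pass-through idea in this diff can be illustrated with a small, stdlib-only sketch. Everything here is hypothetical and simplified: `toy_zen` is a stand-in written for this example (it is not hydra-zen's `Zen` class), the config is a plain `SimpleNamespace`, and no instantiation of sub-configs is performed. Only the treatment of the special `zen_cfg` name mirrors the diff.

```python
import inspect
from types import SimpleNamespace
from typing import Any, Callable

CFG_PARAM = "zen_cfg"  # the special parameter name used in the diff


def toy_zen(func: Callable[..., Any]) -> Callable[[Any], Any]:
    """Toy stand-in for `zen`: fills parameters from config attributes by name."""
    params = inspect.signature(func).parameters

    def wrapper(cfg: Any) -> Any:
        # Extract one config field per parameter, skipping the special name.
        kwargs = {name: getattr(cfg, name) for name in params if name != CFG_PARAM}
        # If the function asks for it, hand over the whole, unmodified config.
        if CFG_PARAM in params:
            kwargs[CFG_PARAM] = cfg
        return func(**kwargs)

    return wrapper


@toy_zen
def task_v1(seed, zen_cfg):
    # Partially migrated: `seed` is extracted, `foo` is still read from the config.
    return seed, zen_cfg.foo


@toy_zen
def task_final(seed, foo):
    # Fully migrated: no reference to the config object remains.
    return seed, foo


cfg = SimpleNamespace(seed=1, foo="bar", unused=[1, 2])
print(task_v1(cfg))
print(task_final(cfg))
```

Both calls return the same pair, which is the property that makes a parameter-by-parameter migration possible.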
Thanks for this, @Jasha10! I had been thinking of including an "escape hatch" like the one you sketched out with [...] The one thing that I'll need to take care of is to make clear to the user that Hydra does not perform singleton instantiation. E.g. they might be tempted to believe that the following will hold:
@zen
def func(model, zen_cfg):
    assert model is instantiate(zen_cfg.model)  # <- this will fail!
(there are other ways, e.g. via interpolation, where this confusion could manifest, but those are independent of [...])
This is sage advice! I have developed an intuition for this, but I never could have put it so succinctly. I will be sure to read this blog post 😄
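The identity pitfall raised in this reply can be shown without Hydra at all. The sketch below is an assumption-laden toy: `Recipe` and `toy_instantiate` are hypothetical names standing in for a config node and for instantiation, and they only capture the one relevant property, namely that each instantiation builds a fresh object.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict


@dataclass(frozen=True)
class Recipe:
    """Hypothetical stand-in for a config node describing how to build an object."""
    target: Callable[..., Any]
    kwargs: Dict[str, Any]


def toy_instantiate(recipe: Recipe) -> Any:
    # Every call constructs a new object; nothing is cached or shared.
    return recipe.target(**recipe.kwargs)


class Model:
    def __init__(self, width: int) -> None:
        self.width = width


recipe = Recipe(Model, {"width": 8})
first = toy_instantiate(recipe)   # what the task function would receive as `model`
second = toy_instantiate(recipe)  # what re-instantiating from the config yields

print(first is second)               # distinct objects
print(first.width == second.width)   # built from the same description
```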
This PR introduces the `zen` decorator, which changes the interface of an arbitrary (Hydra-agnostic) function so that it can accept a Hydra config as its input. The decorator inspects the inner function's signature and extracts (and resolves and instantiates) the appropriate fields from an input config to call said function:
Using `@zen` to improve task functions
`@zen` is designed to help decouple one's task function from the Hydra framework. By doing so, it improves the task function's legibility, versatility, and testability.
Given:
One typically writes task functions like so:
There are several issues with this:
- `cfg` is relatively opaque to users and type-checkers alike.
- `old_task_fn` is tightly coupled to the Hydra framework: you must pass it a config to run it.
`@zen` strives to rectify all of these shortcomings. Let's use it to refactor our task function:
`@zen` makes `new_task_fn` explicit, legible, and boilerplate-free. It works by inspecting the signature of `new_task_fn` and extracting and instantiating the corresponding parameters from our config [1][2]. Given the explicit signature (with optional annotations), users and IDEs can easily understand the context of the task function's body.
Furthermore, one can run the underlying task function, via `.func`, independently of a Hydra app:
Given this accessibility, and because our task function is now free of Hydra-specific boilerplate code, we can easily use and test it outside of the context of our Hydra app.
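The signature-inspection mechanism and the `.func` escape route described above can be sketched with the standard library alone. This is a toy under stated assumptions: `ToyZen` is a hypothetical class written for this illustration rather than hydra-zen's implementation, the config is a `SimpleNamespace`, and no resolution or instantiation of sub-configs takes place.

```python
import inspect
from types import SimpleNamespace
from typing import Any, Callable


class ToyZen:
    """Toy stand-in for a zen-wrapped function (not hydra-zen's implementation)."""

    def __init__(self, func: Callable[..., Any]) -> None:
        self.func = func  # the original, framework-free function
        self.parameters = inspect.signature(func).parameters

    def __call__(self, cfg: Any) -> Any:
        kwargs = {}
        for name, param in self.parameters.items():
            if hasattr(cfg, name):
                # Only the fields named in the signature are read from the config.
                kwargs[name] = getattr(cfg, name)
            elif param.default is inspect.Parameter.empty:
                raise TypeError(f"config is missing required field {name!r}")
            # Otherwise the parameter's own default is used.
        return self.func(**kwargs)


@ToyZen
def new_task_fn(seed: int, scale: float = 2.0) -> float:
    return seed * scale


cfg = SimpleNamespace(seed=3, unused=[1, 2])
print(new_task_fn(cfg))                 # config-driven call
print(new_task_fn.func(5, scale=0.5))   # direct call, no config involved
```

Because the undecorated function stays reachable, a unit test can call it with ordinary arguments and never construct a config.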
`zen` makes it trivial to take any 3rd party function and transform it into a Hydra-compatible task function:
Using `@zen` instead of `@hydra.main`
The object returned by `zen` provides a convenience method, `Zen.hydra_main`, so that users need not double-wrap with `@hydra.main` to create a CLI:
Additional Bells & Whistles
Validation
A `zen`-wrapped function can validate configs without calling the function itself. This makes it easy to test compatibility between your task functions and configs (e.g., as part of your CI/CD process).
Customizing the Wrapper Behavior
One can subclass `hydra_zen.wrapper.Zen` and pass it to `@zen` to modify the wrapped behavior. In the following example we add the ability to log the config (as YAML) upon each call of a zen-wrapped function.
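A rough sketch of the subclassing pattern, using only the standard library. The names `ToyZen` and `LoggingToyZen` are hypothetical stand-ins for this illustration and do not reproduce hydra-zen's `Zen` class or its decorator arguments. The config is serialized as JSON here only because the standard library has no YAML writer.

```python
import inspect
import json
from types import SimpleNamespace
from typing import Any, Callable, List


class ToyZen:
    """Toy base wrapper: calls `func` with config fields matched by parameter name."""

    def __init__(self, func: Callable[..., Any]) -> None:
        self.func = func
        self.parameters = inspect.signature(func).parameters

    def __call__(self, cfg: Any) -> Any:
        return self.func(**{name: getattr(cfg, name) for name in self.parameters})


class LoggingToyZen(ToyZen):
    """Subclass that records a serialized copy of the config on every call."""

    def __init__(self, func: Callable[..., Any]) -> None:
        super().__init__(func)
        self.log: List[str] = []

    def __call__(self, cfg: Any) -> Any:
        # Log first, then defer to the unchanged base behavior.
        self.log.append(json.dumps(vars(cfg), sort_keys=True))
        return super().__call__(cfg)


@LoggingToyZen
def task(x: int, y: int) -> int:
    return x + y


result = task(SimpleNamespace(x=1, y=2))
print(result)
print(task.log)
```

Overriding only `__call__` keeps the extraction logic in one place, so the subclass adds behavior without duplicating it.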
Adding a Pre-Call Step
Recall that `@zen` will automatically instantiate a sub-config prior to passing it to the decorated function. If that instantiated object relies on random behavior, it can be useful to set a seed prior to the instantiation process. We can do this via `@zen(pre_call=...)`:
Without pre-call:
With pre-call:
Validation propagates through zen-wrapped pre-call functions:
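How a pre-call hook and validation might fit together can be sketched as follows. This is a toy built on assumptions: `ToyZen` is a hypothetical stand-in and not hydra-zen's class, "instantiation" is mimicked by calling any callable stored in the config, and validation only checks that required field names are present.

```python
import inspect
import random
from types import SimpleNamespace
from typing import Any, Callable, Optional


class ToyZen:
    """Toy wrapper with a pre-call hook and a name-based validation step."""

    def __init__(self, func: Callable[..., Any],
                 pre_call: Optional[Callable[[Any], Any]] = None) -> None:
        self.func = func
        self.pre_call = pre_call
        self.parameters = inspect.signature(func).parameters

    def validate(self, cfg: Any) -> None:
        # A wrapped pre-call hook is validated against the same config first.
        if isinstance(self.pre_call, ToyZen):
            self.pre_call.validate(cfg)
        missing = [name for name, p in self.parameters.items()
                   if p.default is inspect.Parameter.empty and not hasattr(cfg, name)]
        if missing:
            raise ValueError(f"{self.func.__name__}: config lacks fields {missing}")

    def __call__(self, cfg: Any) -> Any:
        self.validate(cfg)
        if self.pre_call is not None:
            self.pre_call(cfg)  # runs before any config field is evaluated
        kwargs = {}
        for name in self.parameters:
            if hasattr(cfg, name):
                value = getattr(cfg, name)
                # Mimic instantiation: callables in the config are called lazily.
                kwargs[name] = value() if callable(value) else value
        return self.func(**kwargs)


def set_seed(seed: int) -> None:
    random.seed(seed)


def task(sample: float) -> float:
    return sample


cfg = SimpleNamespace(seed=0, sample=random.random)

unseeded_task = ToyZen(task)  # draws depend on the global RNG state
seeded_task = ToyZen(task, pre_call=ToyZen(set_seed))

print(unseeded_task(cfg), unseeded_task(cfg))
print(seeded_task(cfg), seeded_task(cfg))  # identical: seed is set before each draw

bad_cfg = SimpleNamespace(sample=random.random)  # no `seed` field
try:
    seeded_task.validate(bad_cfg)
except ValueError as err:
    print("validation failed:", err)
```

Running the hook before any field is evaluated is what makes the seeded draws repeatable. Validating the hook along with the task means a config missing `seed` is rejected before anything runs.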
Passing Through The Config
Some task functions require access to the full config in order to reach sub-configs. One can include a parameter named `zen_config` [3] in their task function's signature to signal to `zen` that it should pass the full config to that parameter.
Footnotes
1. `@zen` only performs instantiation on extracted fields as needed. Thus it avoids accessing or instantiating parts of a larger config that are not necessary for the given task function.
2. Interpolated fields are resolved by calls mediated through `zen`.
3. You can change this specialized name by subclassing `Zen`.