Skip to content

Commit

Permalink
Add replace to transform
Browse files Browse the repository at this point in the history
  • Loading branch information
thomaspinder committed Aug 9, 2024
1 parent 5b63449 commit 1cc2962
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions gpjax/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,18 @@ def _inner(param):
else:
transformed_value = bijector.forward(param.value)

param.value = transformed_value
param.replace(transformed_value)

return param

transformed_params = jtu.tree_map(
gp_params, *other_params = params.split(Parameter, ...)

transformed_gp_params: nnx.State = jtu.tree_map(
lambda x: _inner(x),
params,
gp_params,
is_leaf=lambda x: isinstance(x, nnx.VariableState),
)
return transformed_params
return nnx.State.merge(transformed_gp_params, *other_params)


class Parameter(nnx.Variable[T]):
Expand Down

0 comments on commit 1cc2962

Please sign in to comment.