Skip to content

Commit

Permalink
Merge pull request #464 from JaxGaussianProcesses/fix-transform
Browse files Browse the repository at this point in the history
Fix transform
  • Loading branch information
thomaspinder authored Aug 14, 2024
2 parents 5b63449 + 1cc2962 commit 684cd8e
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions gpjax/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,18 @@ def _inner(param):
else:
transformed_value = bijector.forward(param.value)

param.value = transformed_value
param.replace(transformed_value)

return param

transformed_params = jtu.tree_map(
gp_params, *other_params = params.split(Parameter, ...)

transformed_gp_params: nnx.State = jtu.tree_map(
lambda x: _inner(x),
params,
gp_params,
is_leaf=lambda x: isinstance(x, nnx.VariableState),
)
return transformed_params
return nnx.State.merge(transformed_gp_params, *other_params)


class Parameter(nnx.Variable[T]):
Expand Down

0 comments on commit 684cd8e

Please sign in to comment.