From 1cc2962717552915e01cf2941e5d0db417f43afb Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Fri, 9 Aug 2024 08:15:41 +0200 Subject: [PATCH] Add replace to transform --- gpjax/parameters.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/gpjax/parameters.py b/gpjax/parameters.py index 8ac9a318..63cce32b 100644 --- a/gpjax/parameters.py +++ b/gpjax/parameters.py @@ -56,16 +56,18 @@ def _inner(param): else: transformed_value = bijector.forward(param.value) - param.value = transformed_value + param.replace(transformed_value) return param - transformed_params = jtu.tree_map( + gp_params, *other_params = params.split(Parameter, ...) + + transformed_gp_params: nnx.State = jtu.tree_map( lambda x: _inner(x), - params, + gp_params, is_leaf=lambda x: isinstance(x, nnx.VariableState), ) - return transformed_params + return nnx.State.merge(transformed_gp_params, *other_params) class Parameter(nnx.Variable[T]):