From 0aa60e8275362890ebe434ae96da96537dcf5182 Mon Sep 17 00:00:00 2001 From: wejpurvis Date: Sun, 25 Feb 2024 12:53:12 +0000 Subject: [PATCH 1/2] Fixed mutable default value error in VelocityKernel dataclass Replaced direct assignments of RBF kernels in VelocityKernel dataclass with field(default_factory=...) oceanmodelling example now works with python 3.11 --- docs/examples/oceanmodelling.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/docs/examples/oceanmodelling.py b/docs/examples/oceanmodelling.py index f917ec04..ec3f0aad 100644 --- a/docs/examples/oceanmodelling.py +++ b/docs/examples/oceanmodelling.py @@ -195,10 +195,17 @@ def dataset_3d(pos, vel): # %% +from dataclasses import field + + @dataclass class VelocityKernel(gpx.kernels.AbstractKernel): - kernel0: gpx.kernels.AbstractKernel = gpx.kernels.RBF(active_dims=[0, 1]) - kernel1: gpx.kernels.AbstractKernel = gpx.kernels.RBF(active_dims=[0, 1]) + kernel0: gpx.kernels.AbstractKernel = field( + default_factory=lambda: gpx.kernels.RBF(active_dims=[0, 1]) + ) + kernel1: gpx.kernels.AbstractKernel = field( + default_factory=lambda: gpx.kernels.RBF(active_dims=[0, 1]) + ) def __call__( self, X: Float[Array, "1 D"], Xp: Float[Array, "1 D"] @@ -429,8 +436,12 @@ def plot_fields( @dataclass class HelmholtzKernel(gpx.kernels.AbstractKernel): # initialise Phi and Psi kernels as any stationary kernel in gpJax - potential_kernel: gpx.kernels.AbstractKernel = gpx.kernels.RBF(active_dims=[0, 1]) - stream_kernel: gpx.kernels.AbstractKernel = gpx.kernels.RBF(active_dims=[0, 1]) + potential_kernel: gpx.kernels.AbstractKernel = field( + default_factory=lambda: gpx.kernels.RBF(active_dims=[0, 1]) + ) + stream_kernel: gpx.kernels.AbstractKernel = field( + default_factory=lambda: gpx.kernels.RBF(active_dims=[0, 1]) + ) def __call__( self, X: Float[Array, "1 D"], Xp: Float[Array, "1 D"] From 027fec9678abd28deca15186e7d621bb7549670b Mon Sep 17 00:00:00 2001 From: wejpurvis Date: Sun, 25 Feb 2024 13:02:01 +0000 Subject: [PATCH 2/2] Fixed mutable default value error in kernel dataclasses Replaced direct assignments of kernels in VelocityKernel and HelmholtzKernel dataclasses with field(default_factory=...) oceanmodelling example now works with python 3.11 --- docs/examples/oceanmodelling.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/examples/oceanmodelling.py b/docs/examples/oceanmodelling.py index ec3f0aad..77f125b1 100644 --- a/docs/examples/oceanmodelling.py +++ b/docs/examples/oceanmodelling.py @@ -10,7 +10,7 @@ from jax import config config.update("jax_enable_x64", True) -from dataclasses import dataclass +from dataclasses import dataclass, field from jax import hessian from jax import config @@ -195,7 +195,6 @@ def dataset_3d(pos, vel): # %% -from dataclasses import field @dataclass