Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed mutable default value error in VelocityKernel and HelmholtzKernel dataclasses in python 3.11 #439

Merged
merged 2 commits into from
Feb 25, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions docs/examples/oceanmodelling.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from jax import config

config.update("jax_enable_x64", True)
from dataclasses import dataclass
from dataclasses import dataclass, field

from jax import hessian
from jax import config
Expand Down Expand Up @@ -195,10 +195,16 @@ def dataset_3d(pos, vel):


# %%


@dataclass
class VelocityKernel(gpx.kernels.AbstractKernel):
kernel0: gpx.kernels.AbstractKernel = gpx.kernels.RBF(active_dims=[0, 1])
kernel1: gpx.kernels.AbstractKernel = gpx.kernels.RBF(active_dims=[0, 1])
kernel0: gpx.kernels.AbstractKernel = field(
default_factory=lambda: gpx.kernels.RBF(active_dims=[0, 1])
)
kernel1: gpx.kernels.AbstractKernel = field(
default_factory=lambda: gpx.kernels.RBF(active_dims=[0, 1])
)

def __call__(
self, X: Float[Array, "1 D"], Xp: Float[Array, "1 D"]
Expand Down Expand Up @@ -429,8 +435,12 @@ def plot_fields(
@dataclass
class HelmholtzKernel(gpx.kernels.AbstractKernel):
# initialise Phi and Psi kernels as any stationary kernel in gpJax
potential_kernel: gpx.kernels.AbstractKernel = gpx.kernels.RBF(active_dims=[0, 1])
stream_kernel: gpx.kernels.AbstractKernel = gpx.kernels.RBF(active_dims=[0, 1])
potential_kernel: gpx.kernels.AbstractKernel = field(
default_factory=lambda: gpx.kernels.RBF(active_dims=[0, 1])
)
stream_kernel: gpx.kernels.AbstractKernel = field(
default_factory=lambda: gpx.kernels.RBF(active_dims=[0, 1])
)

def __call__(
self, X: Float[Array, "1 D"], Xp: Float[Array, "1 D"]
Expand Down
Loading