Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix config import #410

Merged
merged 1 commit into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/examples/barycentres.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

# %%
# Enable Float64 for more stable matrix inversions.
from jax.config import config
from jax import config

config.update("jax_enable_x64", True)

Expand Down
2 changes: 1 addition & 1 deletion docs/examples/bayesian_optimisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

# %%
# Enable Float64 for more stable matrix inversions.
from jax.config import config
from jax import config

config.update("jax_enable_x64", True)

Expand Down
2 changes: 1 addition & 1 deletion docs/examples/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

# %%
# Enable Float64 for more stable matrix inversions.
from jax.config import config
from jax import config

config.update("jax_enable_x64", True)

Expand Down
2 changes: 1 addition & 1 deletion docs/examples/collapsed_vi.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

# %%
# Enable Float64 for more stable matrix inversions.
from jax.config import config
from jax import config

config.update("jax_enable_x64", True)

Expand Down
2 changes: 1 addition & 1 deletion docs/examples/constructing_new_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

# %%
# Enable Float64 for more stable matrix inversions.
from jax.config import config
from jax import config

config.update("jax_enable_x64", True)

Expand Down
2 changes: 1 addition & 1 deletion docs/examples/deep_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

# %%
# Enable Float64 for more stable matrix inversions.
from jax.config import config
from jax import config

config.update("jax_enable_x64", True)

Expand Down
2 changes: 1 addition & 1 deletion docs/examples/graph_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

# %%
# Enable Float64 for more stable matrix inversions.
from jax.config import config
from jax import config

config.update("jax_enable_x64", True)

Expand Down
2 changes: 1 addition & 1 deletion docs/examples/intro_to_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

# %%
# Enable Float64 for more stable matrix inversions.
from jax.config import config
from jax import config

config.update("jax_enable_x64", True)

Expand Down
2 changes: 1 addition & 1 deletion docs/examples/likelihoods_guide.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@

# +
# Enable Float64 for more stable matrix inversions.
from jax.config import config
from jax import config

config.update("jax_enable_x64", True)

Expand Down
4 changes: 2 additions & 2 deletions docs/examples/oceanmodelling.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
#
# Surface drifters are measurement devices that measure the dynamics and circulation patterns of the world's oceans. Studying and predicting ocean currents are important to climate research, for example, forecasting and predicting oil spills, oceanographic surveying of eddies and upwelling, or providing information on the distribution of biomass in ecosystems. We will be using the [Gulf Drifters Open dataset](https://zenodo.org/record/4421585), which contains all publicly available surface drifter trajectories from the Gulf of Mexico spanning 28 years.
# %%
from jax.config import config
from jax import config

config.update("jax_enable_x64", True)
from dataclasses import dataclass

from jax import hessian
from jax.config import config
from jax import config
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import (
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/poisson.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import matplotlib as mpl
import matplotlib.pyplot as plt
import tensorflow_probability.substrates.jax as tfp
from jax.config import config
from jax import config
from jaxtyping import install_import_hook

with install_import_hook("gpjax", "beartype.beartype"):
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

# %%
# Enable Float64 for more stable matrix inversions.
from jax.config import config
from jax import config

config.update("jax_enable_x64", True)

Expand Down
2 changes: 1 addition & 1 deletion docs/examples/regression_mo.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

# %%
# Enable Float64 for more stable matrix inversions.
from jax.config import config
from jax import config

config.update("jax_enable_x64", True)

Expand Down
2 changes: 1 addition & 1 deletion docs/examples/spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
#
# %%
# Enable Float64 for more stable matrix inversions.
from jax.config import config
from jax import config

config.update("jax_enable_x64", True)

Expand Down
2 changes: 1 addition & 1 deletion docs/examples/uncollapsed_vi.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

# %%
# Enable Float64 for more stable matrix inversions.
from jax.config import config
from jax import config

config.update("jax_enable_x64", True)

Expand Down
2 changes: 1 addition & 1 deletion docs/examples/yacht.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

# %%
# Enable Float64 for more stable matrix inversions.
from jax.config import config
from jax import config

config.update("jax_enable_x64", True)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_citations.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from jax.config import config
from jax import config

config.update("jax_enable_x64", True)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
except ImportError:
ValidationErrors = ValueError

from jax.config import config
from jax import config
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
Expand Down
2 changes: 1 addition & 1 deletion tests/test_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from dataclasses import dataclass

from jax.config import config
from jax import config
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import (
Expand Down
2 changes: 1 addition & 1 deletion tests/test_gaussian_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# ==============================================================================


from jax.config import config
from jax import config
import jax.numpy as jnp
import jax.random as jr
import pytest
Expand Down
2 changes: 1 addition & 1 deletion tests/test_gps.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
Type,
)

from jax.config import config
from jax import config
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
Expand Down
2 changes: 1 addition & 1 deletion tests/test_integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import typing as tp

import jax
from jax.config import config
from jax import config
import jax.numpy as jnp
import numpy as np
import pytest
Expand Down
2 changes: 1 addition & 1 deletion tests/test_kernels/test_approximations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from cola.ops import Dense
import jax
from jax.config import config
from jax import config
import jax.numpy as jnp
import jax.random as jr
import pytest
Expand Down
2 changes: 1 addition & 1 deletion tests/test_kernels/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
field,
)

from jax.config import config
from jax import config
import jax.numpy as jnp
from jaxtyping import (
Array,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_kernels/test_non_euclidean.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# # limitations under the License.

from cola.ops import I_like
from jax.config import config
from jax import config
import jax.numpy as jnp
import jax.random as jr
import networkx as nx
Expand Down
2 changes: 1 addition & 1 deletion tests/test_kernels/test_nonstationary.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from cola.ops import LinearOperator
import jax
from jax.config import config
from jax import config
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
Expand Down
2 changes: 1 addition & 1 deletion tests/test_kernels/test_stationary.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from cola.ops import LinearOperator
import jax
from jax.config import config
from jax import config
import jax.numpy as jnp
import jax.tree_util as jtu
import pytest
Expand Down
2 changes: 1 addition & 1 deletion tests/test_likelihoods.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
List,
)

from jax.config import config
from jax import config
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
Expand Down
2 changes: 1 addition & 1 deletion tests/test_objectives.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import jax
from jax.config import config
from jax import config
import jax.numpy as jnp
import jax.random as jr
import pytest
Expand Down
2 changes: 1 addition & 1 deletion tests/test_variational_families.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
Tuple,
)

from jax.config import config
from jax import config
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
Expand Down