Skip to content

Commit

Permalink
Merge pull request #264 from JaxGaussianProcesses/henry/quickfix/inst…
Browse files Browse the repository at this point in the history
…ructions

Tiny doc improvement
  • Loading branch information
thomaspinder authored May 16, 2023
2 parents 44ae92b + 2ba979b commit 655a00d
Show file tree
Hide file tree
Showing 8 changed files with 20 additions and 22 deletions.
3 changes: 0 additions & 3 deletions .github/workflows/build_docs.yml
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
name: Build the documentation

on:
pull_request:
branches:
- main
push:
branches:
- main
Expand Down
16 changes: 8 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -167,14 +167,6 @@ pip install gpjax
>
> This version is possibly unstable and may contain bugs.
Clone a copy of the repository to your local machine and run the setup
configuration in development mode.
```bash
git clone https://github.com/JaxGaussianProcesses/GPJax.git
cd GPJax
poetry install
```
> **Note**
>
> We advise you create virtual environment before installing:
Expand All @@ -189,6 +181,14 @@ poetry install
> poetry run pytest
> ```
Clone a copy of the repository to your local machine and run the setup
configuration in development mode.
```bash
git clone https://github.com/JaxGaussianProcesses/GPJax.git
cd GPJax
poetry install
```
# Citing GPJax

If you use GPJax in your research, please cite our [JOSS paper](https://joss.theoj.org/papers/10.21105/joss.04455#).
Expand Down
6 changes: 3 additions & 3 deletions docs/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ pip install gpjax

## GPU/TPU support

Fancy using GPJax on GPU/TPU? Then you'll need to install JAX with the relevant
hardware acceleration support as detailed in the
Fancy using GPJax on GPU/TPU? Then you'll need to install JAX with the relevant
hardware acceleration support as detailed in the
[JAX installation guide](https://github.com/google/jax#installation).


Expand Down Expand Up @@ -46,4 +46,4 @@ hardware acceleration support as detailed in the

```bash
poetry run pytest tests/
```
```
2 changes: 1 addition & 1 deletion docs/stylesheets/permalinks.css
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,4 @@
order: -1;
margin-left: calc(var(--permalink-size) * -1 - var(--permalink-spacing)) !important;
}
}
}
2 changes: 1 addition & 1 deletion gpjax/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
meta_flatten,
meta_leaves,
meta_map,
static_field,
save_tree,
static_field,
)
from gpjax.base.param import param_field

Expand Down
4 changes: 1 addition & 3 deletions tests/test_base/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,7 @@
import jax.numpy as jnp
import jax.tree_util as jtu
import pytest
from simple_pytree import (
Pytree,
)
from simple_pytree import Pytree
import tensorflow_probability.substrates.jax.bijectors as tfb

from gpjax.base.module import (
Expand Down
7 changes: 5 additions & 2 deletions tests/test_kernels/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
# limitations under the License.
# ==============================================================================

from dataclasses import dataclass, field
from dataclasses import (
dataclass,
field,
)

from jax.config import config
import jax.numpy as jnp
Expand Down Expand Up @@ -55,7 +58,7 @@ def test_abstract_kernel():
# Create a dummy kernel class with __call__ implemented:
@dataclass
class DummyKernel(AbstractKernel):
test_a: Float[Array, "1"] = field(default_factory = lambda: jnp.array([1.0]))
test_a: Float[Array, "1"] = field(default_factory=lambda: jnp.array([1.0]))
test_b: Float[Array, "1"] = param_field(
jnp.array([2.0]), bijector=tfb.Softplus()
)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_linops/test_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
import jax.numpy as jnp
import jax.tree_util as jtu
import pytest
from gpjax.base import static_field

from gpjax.base import static_field
from gpjax.linops.linear_operator import LinearOperator


Expand Down

0 comments on commit 655a00d

Please sign in to comment.