Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update w/ Dan comments #130

Merged
merged 1 commit into from
Nov 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,11 @@ def find_version(*file_paths):
bibtex_style = "unsrt"
bibtex_reference_style = "author_year"
nb_execution_mode = "auto"
nbsphinx_allow_errors = True
nbsphinx_allow_errors = False
nbsphinx_custom_formats = {
".pct.py": ["jupytext.reads", {"fmt": "py:percent"}],
}
jupyter_execute_notebooks = "cache"
nbsphinx_execute_arguments = ["--InlineBackend.figure_formats={'svg', 'pdf'}"]

# Latex commands
# mathjax_path = "https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"
Expand Down Expand Up @@ -206,7 +206,6 @@ def find_version(*file_paths):
# }



# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
Expand Down
7 changes: 3 additions & 4 deletions docs/contributing.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ many ways to contibute, including:
- Fixing outstanding [issues](https://github.com/thomaspinder/GPJax/issues)
(bugs).
- Extending or improving our [codebase](https://github.com/thomaspinder/GPJax).
- Submitting issues related to bugs or desired enhancements.


# Code of conduct
Expand Down Expand Up @@ -44,7 +43,7 @@ install our `pre-commit hooks`, `commit` and `push` your code.
you through every detail!

:::{attention} Before opening a pull request we recommend you check our [pull
request checklist](#pull-request-checklist). :::
request checklist](#pull-request-checklist).


## Step-by-step guide:
Expand All @@ -69,7 +68,7 @@ request checklist](#pull-request-checklist). :::
```

:::{attention} Always use a `feature` branch. It's good practice to avoid
work on the ``main`` branch of any repository. :::
work on the ``main`` branch of any repository.

4. Project requirements are in ``requirements.txt``. We suggest using a
[virtual environment](https://docs.python-guide.org/dev/virtualenvs/) for
Expand All @@ -87,7 +86,7 @@ request checklist](#pull-request-checklist). :::
```
:::{warning} Please ensure you have done this before commiting any files. If
successful, this will print the following output `pre-commit installed at
.git/hooks/pre-commit`. :::
.git/hooks/pre-commit`.

6. Add changed files using `git add` and then `git commit` files to record your
changes locally:
Expand Down
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ nbsphinx>=0.8.0
nb-black==1.0.7
matplotlib==3.3.3
tensorflow-probability>=0.16.0
seaborn
sphinx-copybutton
networkx>=2.0.0
pandoc
Expand Down
13 changes: 7 additions & 6 deletions examples/tfp_integration.pct.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,18 +99,19 @@

# %%
import tensorflow_probability.substrates.jax as tfp
import tensorflow_probability.substrates.jax.bijectors as tfb

tfd = tfp.distributions

priors = gpx.parameters.copy_dict_structure(params)
priors["kernel"]["lengthscale"] = tfd.Gamma(
concentration=jnp.array(1.0), rate=jnp.array(1.0)
priors["kernel"]["lengthscale"] = tfd.TransformedDistribution(
tfd.Gamma(concentration=jnp.array(1.0), rate=jnp.array(1.0)), tfb.Softplus()
)
priors["kernel"]["variance"] = tfd.Gamma(
concentration=jnp.array(1.0), rate=jnp.array(1.0)
priors["kernel"]["variance"] = tfd.TransformedDistribution(
tfd.Gamma(concentration=jnp.array(1.0), rate=jnp.array(1.0)), tfb.Softplus()
)
priors["likelihood"]["obs_noise"] = tfd.Gamma(
concentration=jnp.array(1.0), rate=jnp.array(1.0)
priors["likelihood"]["obs_noise"] = tfd.TransformedDistribution(
tfd.Gamma(concentration=jnp.array(1.0), rate=jnp.array(1.0)), tfb.Softplus()
)

# %% [markdown]
Expand Down
2 changes: 1 addition & 1 deletion gpjax/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

__config = None

FillTriangular = dx.Chain([tfb.FillTriangular()])
FillTriangular = dx.Chain([tfb.FillTriangular(), tfb.Softplus()])
Identity = dx.Lambda(forward=lambda x: x, inverse=lambda x: x)
Softplus = dx.Lambda(
forward=lambda x: jnp.log(1 + jnp.exp(x)),
Expand Down