Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Performance improvement in log_det #354

Closed
wants to merge 1 commit into from
Closed

Conversation

frazane
Copy link
Contributor

@frazane frazane commented Aug 9, 2023

Type of changes

  • Bug fix
  • New feature
  • Documentation / docstrings
  • Tests
  • Performance

Checklist

  • I've formatted the new code by running poetry run pre-commit run --all-files --show-diff-on-failure before committing.
  • I've added tests for new code.
  • I've added docstrings for the new code.

Description

Performance improvement in the computation of log_det from LinearOperator.

Edit: unfortunately computing the product before the log can also result in overflow. For small matrices this is not an issue but it is for large ones.

import timeit

import jax
import jax.numpy as jnp
import jax.random as jr
from gpjax.linops import DenseLinearOperator

jax.config.update("jax_enable_x64", True)

key = jr.PRNGKey(123)
cov = jr.normal(key, (100, 100)).astype(jnp.float64)
cov = cov @ cov.T

linop = DenseLinearOperator(cov)
root = linop.to_root()
expr1 = "jnp.sum(jnp.log(root.diagonal()))"
res1 = jnp.sum(jnp.log(root.diagonal()))
time1 = timeit.timeit(expr1, number=1000, globals=globals())

expr2 = "jnp.log(jnp.prod(root.diagonal()))"
res2 = jnp.log(jnp.prod(root.diagonal()))
time2 = timeit.timeit(expr2, number=1000, globals=globals())

print(f"'{expr1}': {time1:.3}s")
print(f"'{expr2}': {time2:.3}s")
assert res1 == res2
'jnp.sum(jnp.log(root.diagonal()))': 0.0363s
'jnp.log(jnp.prod(root.diagonal()))': 0.0178s

@frazane frazane added the performance Performance label Aug 9, 2023
@frazane frazane changed the title logdet efficiency Performance improvement in log_det Aug 9, 2023
@frazane frazane marked this pull request as draft August 9, 2023 08:48
@thomaspinder
Copy link
Collaborator

Is this ready for merging @frazane ?

@frazane
Copy link
Contributor Author

frazane commented Aug 10, 2023

@thomaspinder unfortunately I realized there's an issue with numerical stability: floating point overflow occurs for very large matrices (large n), because jnp.prod(root.diagonal()) produces very large numbers.

I would not merge this PR for now. Maybe we could think of some conditional computation based on the size of the matrix. Or do some split-apply-combine kind of thing but I fear this could defeat the purpose of improving performance.

@daniel-dodd daniel-dodd closed this Sep 6, 2023
@frazane frazane deleted the logdet-efficiency branch November 15, 2023 08:04
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
performance Performance
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants