Skip to content

Commit

Permalink
Remove unnecessary SDE resampling in PPO update (#1933)
Browse files Browse the repository at this point in the history
* Remove unnecessary SDE resampling in PPO update

* Update changelog.rst

* Update version

* Update PyTorch version on CI

* Update ruff

* Limit NumPy version

* Reformat

---------

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
  • Loading branch information
brn-dev and araffin authored Jun 29, 2024
1 parent 4efee92 commit 24ebf1a
Show file tree
Hide file tree
Showing 7 changed files with 11 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
run: |
python -m pip install --upgrade pip
# cpu version of pytorch
pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu
pip install torch==2.3.1 --index-url https://download.pytorch.org/whl/cpu
# Install Atari Roms
pip install autorom
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ lint:
# see https://www.flake8rules.com/
ruff check ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full
# exit-zero treats all errors as warnings.
ruff check ${LINT_PATHS} --exit-zero
ruff check ${LINT_PATHS} --exit-zero --output-format=concise

format:
# Sort imports
Expand Down
6 changes: 5 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Changelog
==========

Release 2.4.0a3 (WIP)
Release 2.4.0a4 (WIP)
--------------------------

Breaking Changes:
Expand All @@ -19,6 +19,7 @@ Bug Fixes:
- Cast type in compute gae method to avoid error when using torch compile (@amjames)
- ``CallbackList`` now sets the ``.parent`` attribute of child callbacks to its own ``.parent``. (will-maclean)
- Fixed error when loading a model that has ``net_arch`` manually set to ``None`` (@jak3122)
- Set requirement numpy<2.0 until PyTorch is compatible (https://github.com/pytorch/pytorch/issues/107302)

`SB3-Contrib`_
^^^^^^^^^^^^^^
Expand All @@ -35,6 +36,8 @@ Deprecations:
Others:
^^^^^^^
- Fixed various typos (@cschindlbeck)
- Remove unnecessary SDE noise resampling in PPO update (@brn-dev)
- Updated PyTorch version on CI to 2.3.1

Bug Fixes:
^^^^^^^^^^
Expand Down Expand Up @@ -1664,3 +1667,4 @@ And all the contributors:
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto
@lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger
@marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck @peteole @jak3122 @will-maclean
@brn-dev
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@
package_data={"stable_baselines3": ["py.typed", "version.txt"]},
install_requires=[
"gymnasium>=0.28.1,<0.30",
"numpy>=1.20",
"numpy>=1.20,<2.0", # PyTorch not compatible https://github.com/pytorch/pytorch/issues/107302
"torch>=1.13",
# For saving models
"cloudpickle",
Expand Down
4 changes: 0 additions & 4 deletions stable_baselines3/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,10 +210,6 @@ def train(self) -> None:
# Convert discrete action from float to long
actions = rollout_data.actions.long().flatten()

# Re-sample the noise matrix because the log_std has changed
if self.use_sde:
self.policy.reset_noise(self.batch_size)

values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
values = values.flatten()
# Normalize advantage
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.4.0a3
2.4.0a4
4 changes: 2 additions & 2 deletions tests/test_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,11 +791,11 @@ def test_cast_lr_schedule(tmp_path):
# Note: for recent version of numpy, np.float64 is a subclass of float
# so we need to use type here
# assert isinstance(model.lr_schedule(1.0), float)
assert type(model.lr_schedule(1.0)) is float # noqa: E721
assert type(model.lr_schedule(1.0)) is float
assert np.allclose(model.lr_schedule(0.5), 0.5 * np.sin(1.0))
model.save(tmp_path / "ppo.zip")
model = PPO.load(tmp_path / "ppo.zip")
assert type(model.lr_schedule(1.0)) is float # noqa: E721
assert type(model.lr_schedule(1.0)) is float
assert np.allclose(model.lr_schedule(0.5), 0.5 * np.sin(1.0))


Expand Down

0 comments on commit 24ebf1a

Please sign in to comment.