Skip to content

Commit

Permalink
Cast learning_rate to float lambda for pickle safety when doing model…
Browse files Browse the repository at this point in the history
….load (DLR-RM#1901)

* create failing test for unpickle error

* Fix learning_rate argument causing failure in weights_only=True if passed a function with non-float types

* Updated with feedback from araffin on PR#1901

* Update test and version

* Update changelog and SBX doc

---------

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
  • Loading branch information
2 people authored and friedeggs committed Jul 22, 2024
1 parent e8f3989 commit 89694c5
Show file tree
Hide file tree
Showing 7 changed files with 43 additions and 11 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ Documentation is available online: [https://sb3-contrib.readthedocs.io/](https:/

## Stable-Baselines Jax (SBX)

[Stable Baselines Jax (SBX)](https://github.com/araffin/sbx) is a proof of concept version of Stable-Baselines3 in Jax.
[Stable Baselines Jax (SBX)](https://github.com/araffin/sbx) is a proof of concept version of Stable-Baselines3 in Jax, with recent algorithms like DroQ or CrossQ.

It provides a minimal number of features compared to SB3 but can be much faster (up to 20x times!): https://twitter.com/araffin2/status/1590714558628253698

Expand Down
3 changes: 2 additions & 1 deletion docs/guide/algos.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ Actions ``gym.spaces``:

.. note::

More algorithms (like QR-DQN or TQC) are implemented in our :ref:`contrib repo <sb3_contrib>`.
More algorithms (like QR-DQN or TQC) are implemented in our :ref:`contrib repo <sb3_contrib>`
and in our :ref:`SBX (SB3 + Jax) repo <sbx>` (DroQ, CrossQ, ...).

.. note::

Expand Down
15 changes: 9 additions & 6 deletions docs/guide/sbx.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Implemented algorithms:
- Deep Q Network (DQN)
- Twin Delayed DDPG (TD3)
- Deep Deterministic Policy Gradient (DDPG)
- Batch Normalization in Deep Reinforcement Learning (CrossQ)


As SBX follows SB3 API, it is also compatible with the `RL Zoo <https://github.com/DLR-RM/rl-baselines3-zoo>`_.
Expand All @@ -29,16 +30,17 @@ For that you will need to create two files:
import rl_zoo3
import rl_zoo3.train
from rl_zoo3.train import train
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ
rl_zoo3.ALGOS["ddpg"] = DDPG
rl_zoo3.ALGOS["dqn"] = DQN
rl_zoo3.ALGOS["droq"] = DroQ
# See SBX readme to use DroQ configuration
# rl_zoo3.ALGOS["droq"] = DroQ
rl_zoo3.ALGOS["sac"] = SAC
rl_zoo3.ALGOS["ppo"] = PPO
rl_zoo3.ALGOS["td3"] = TD3
rl_zoo3.ALGOS["tqc"] = TQC
rl_zoo3.ALGOS["crossq"] = CrossQ
rl_zoo3.train.ALGOS = rl_zoo3.ALGOS
rl_zoo3.exp_manager.ALGOS = rl_zoo3.ALGOS
Expand All @@ -56,16 +58,17 @@ Then you can call ``python train_sbx.py --algo sac --env Pendulum-v1`` and use t
import rl_zoo3
import rl_zoo3.enjoy
from rl_zoo3.enjoy import enjoy
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ
rl_zoo3.ALGOS["ddpg"] = DDPG
rl_zoo3.ALGOS["dqn"] = DQN
rl_zoo3.ALGOS["droq"] = DroQ
# See SBX readme to use DroQ configuration
# rl_zoo3.ALGOS["droq"] = DroQ
rl_zoo3.ALGOS["sac"] = SAC
rl_zoo3.ALGOS["ppo"] = PPO
rl_zoo3.ALGOS["td3"] = TD3
rl_zoo3.ALGOS["tqc"] = TQC
rl_zoo3.ALGOS["crossq"] = CrossQ
rl_zoo3.enjoy.ALGOS = rl_zoo3.ALGOS
rl_zoo3.exp_manager.ALGOS = rl_zoo3.ALGOS
Expand Down
14 changes: 13 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,18 @@
Changelog
==========

Release 2.3.1 (2024-04-22)
--------------------------

Bug Fixes:
^^^^^^^^^^
- Cast return value of learning rate schedule to float, to avoid issue when loading model because of ``weights_only=True`` (@markscsmith)

Documentation:
^^^^^^^^^^^^^^
- Updated SBX documentation (CrossQ and deprecated DroQ)


Release 2.3.0 (2024-03-31)
--------------------------

Expand Down Expand Up @@ -1593,4 +1605,4 @@ And all the contributors:
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong @ReHoss
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto
@lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger
@marekm4 @stagoverflow @rushitnshah
@marekm4 @stagoverflow @rushitnshah @markscsmith
4 changes: 3 additions & 1 deletion stable_baselines3/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ def get_schedule_fn(value_schedule: Union[Schedule, float]) -> Schedule:
value_schedule = constant_fn(float(value_schedule))
else:
assert callable(value_schedule)
return value_schedule
# Cast to float to avoid unpickling errors to enable weights_only=True, see GH#1900
# Some types are have odd behaviors when part of a Schedule, like numpy floats
return lambda progress_remaining: float(value_schedule(progress_remaining))


def get_linear_fn(start: float, end: float, end_fraction: float) -> Schedule:
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.3.0
2.3.1
14 changes: 14 additions & 0 deletions tests/test_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,3 +783,17 @@ def test_no_resource_warning(tmp_path):
fp.seek(0)
model.load_replay_buffer(fp)
assert not fp.closed


def test_cast_lr_schedule(tmp_path):
# See GH#1900
model = PPO("MlpPolicy", "Pendulum-v1", learning_rate=lambda t: t * np.sin(1.0))
# Note: for recent version of numpy, np.float64 is a subclass of float
# so we need to use type here
# assert isinstance(model.lr_schedule(1.0), float)
assert type(model.lr_schedule(1.0)) is float # noqa: E721
assert np.allclose(model.lr_schedule(0.5), 0.5 * np.sin(1.0))
model.save(tmp_path / "ppo.zip")
model = PPO.load(tmp_path / "ppo.zip")
assert type(model.lr_schedule(1.0)) is float # noqa: E721
assert np.allclose(model.lr_schedule(0.5), 0.5 * np.sin(1.0))

0 comments on commit 89694c5

Please sign in to comment.