Skip to content

Commit

Permalink
Fix loading of optimizer with older DQN models (#1978)
Browse files Browse the repository at this point in the history
  • Loading branch information
araffin authored Jul 26, 2024
1 parent 000544c commit bd3c0c6
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 5 deletions.
11 changes: 10 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,16 @@
Changelog
==========

Release 2.4.0a6 (WIP)
Release 2.4.0a7 (WIP)
--------------------------

.. note::

DQN (and QR-DQN) models saved with SB3 < 2.4.0 will show a warning about
truncation of optimizer state when loaded with SB3 >= 2.4.0.
To suppress the warning, simply save the model again.
You can find more info in `PR #1963 <https://github.com/DLR-RM/stable-baselines3/pull/1963>`_

Breaking Changes:
^^^^^^^^^^^^^^^^^

Expand All @@ -28,9 +35,11 @@ Bug Fixes:

`RL Zoo`_
^^^^^^^^^
- Updated default hyperparameters for TQC/SAC for Swimmer-v4 (decreased gamma for more consistent results)

`SBX`_ (SB3 + Jax)
^^^^^^^^^^^^^^^^^^
- Added CNN support for DQN

Deprecations:
^^^^^^^^^^^^^
Expand Down
27 changes: 25 additions & 2 deletions stable_baselines3/common/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,13 +742,13 @@ def load( # noqa: C901
# put state_dicts back in place
model.set_parameters(params, exact_match=True, device=device)
except RuntimeError as e:
# Patch to load Policy saved using SB3 < 1.7.0
# Patch to load policies saved using SB3 < 1.7.0
# the error is probably due to old policy being loaded
# See https://github.com/DLR-RM/stable-baselines3/issues/1233
if "pi_features_extractor" in str(e) and "Missing key(s) in state_dict" in str(e):
model.set_parameters(params, exact_match=False, device=device)
warnings.warn(
"You are probably loading a model saved with SB3 < 1.7.0, "
"You are probably loading a A2C/PPO model saved with SB3 < 1.7.0, "
"we deactivated exact_match so you can save the model "
"again to avoid issues in the future "
"(see https://github.com/DLR-RM/stable-baselines3/issues/1233 for more info). "
Expand All @@ -757,6 +757,29 @@ def load( # noqa: C901
)
else:
raise e
except ValueError as e:
# Patch to load DQN policies saved using SB3 < 2.4.0
# The target network params are no longer in the optimizer
# See https://github.com/DLR-RM/stable-baselines3/pull/1963
saved_optim_params = params["policy.optimizer"]["param_groups"][0]["params"] # type: ignore[index]
n_params_saved = len(saved_optim_params)
n_params = len(model.policy.optimizer.param_groups[0]["params"])
if n_params_saved == 2 * n_params:
# Truncate to include only online network params
params["policy.optimizer"]["param_groups"][0]["params"] = saved_optim_params[:n_params] # type: ignore[index]

model.set_parameters(params, exact_match=True, device=device)
warnings.warn(
"You are probably loading a DQN model saved with SB3 < 2.4.0, "
"we truncated the optimizer state so you can save the model "
"again to avoid issues in the future "
"(see https://github.com/DLR-RM/stable-baselines3/pull/1963 for more info). "
f"Original error: {e} \n"
"Note: the model should still work fine, this only a warning."
)
else:
raise e

# put other pytorch variables back in place
if pytorch_variables is not None:
for name in pytorch_variables:
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.4.0a6
2.4.0a7
14 changes: 13 additions & 1 deletion tests/test_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def test_save_load_env_cnn(tmp_path, model_class):
# clear file from os
os.remove(tmp_path / "test_save.zip")

# Check we can load models saved with SB3 < 1.7.0
# Check we can load A2C/PPO models saved with SB3 < 1.7.0
if model_class == A2C:
del model.policy.pi_features_extractor
model.save(tmp_path / "test_save")
Expand Down Expand Up @@ -809,3 +809,15 @@ def test_save_load_net_arch_none(tmp_path):
# None has been replaced by the default net arch
assert model.policy.net_arch is not None
os.remove(tmp_path / "ppo.zip")


def test_save_load_no_target_params(tmp_path):
    """Check that DQN models saved with SB3 < 2.4.0 can still be loaded.

    The optimizer is rebuilt over all policy parameters so that the saved
    optimizer state also includes the target network parameters. Loading
    such an archive is expected to emit a ``UserWarning``, and the reloaded
    model must still be able to train.
    """
    archive_path = tmp_path / "test_save.zip"

    legacy_model = DQN("MlpPolicy", "CartPole-v1", buffer_size=10000, learning_starts=4)
    wrapped_env = legacy_model.get_env()

    # Include target net params in the optimizer before saving
    legacy_policy = legacy_model.policy
    legacy_policy.optimizer = th.optim.Adam(legacy_policy.parameters(), lr=0.001)
    legacy_model.save(tmp_path / "test_save")

    with pytest.warns(UserWarning):
        reloaded_model = DQN.load(str(archive_path), env=wrapped_env)
        reloaded_model.learn(20)

    # clear file from os
    os.remove(archive_path)

0 comments on commit bd3c0c6

Please sign in to comment.