Get/set parameters and review of saving and loading (#138)
* Update comments and docstrings

* Rename get_torch_variables to private and update docs

* Clarify documentation on data, params and tensors

* Make excluded_save_params private and update docs

* Rename get_torch_variable_names to get_torch_save_params for a more descriptive name

* Simplify saving code and update docs on params vs tensors

* Rename saved item tensors to pytorch_variables for clarity

* Reformat

* Fix a typo

* Add get/set_parameters, update tests accordingly

* Use f-strings for formatting

* Fix load docstring

* Reorganize functions in BaseClass

* Update changelog

* Add library version to the stored models

* Actually run isort this time

* Fix flake8 complaints and also fix testing code

* Fix isort

* ...and black

* Fix set_random_seed

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
3 people authored Sep 24, 2020
1 parent 00595b0 commit 9855486
Showing 8 changed files with 420 additions and 290 deletions.
10 changes: 10 additions & 0 deletions docs/misc/changelog.rst
@@ -16,13 +16,15 @@ New Features:
 - Added ``StopTrainingOnMaxEpisodes`` to callback collection (@xicocaio)
 - Added ``device`` keyword argument to ``BaseAlgorithm.load()`` (@liorcohen5)
 - Callbacks have access to rollout collection locals as in SB2. (@PartiallyTyped)
+- Added ``get_parameters`` and ``set_parameters`` for accessing/setting parameters of the agent
 - Added actor/critic loss logging for TD3. (@mloo3)
 
 Bug Fixes:
 ^^^^^^^^^^
 - Fixed a bug where the environment was reset twice when using ``evaluate_policy``
 - Fixed logging of ``clip_fraction`` in PPO (@diditforlulz273)
 - Fixed a bug where CUDA support was wrongly checked when passing the GPU index, e.g., ``device="cuda:0"`` (@liorcohen5)
+- Fixed a bug where the random seed was not properly set on CUDA when passing the GPU index
 
 Deprecations:
 ^^^^^^^^^^^^^
@@ -33,6 +35,14 @@ Others:
 - Fixed type annotation of ``make_vec_env`` (@ManifoldFR)
 - Removed ``AlreadySteppingError`` and ``NotSteppingError`` that were not used
 - Fixed typos in SAC and TD3
+- Renamed ``BaseClass.get_torch_variables`` -> ``BaseClass._get_torch_save_params`` and
+  ``BaseClass.excluded_save_params`` -> ``BaseClass._excluded_save_params``
+- Reorganized functions for clarity in ``BaseClass`` (save/load functions close to each other, private
+  functions at top)
+- Clarified docstrings on what is saved and loaded to/from files
+- Renamed saved items ``tensors`` to ``pytorch_variables`` for clarity
+- Simplified ``save_to_zip_file`` function by removing duplicate code
+- Stored the library version along with the saved models
 
 Documentation:
 ^^^^^^^^^^^^^^
472 changes: 285 additions & 187 deletions stable_baselines3/common/base_class.py

Large diffs are not rendered by default.
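
Since the ``base_class.py`` diff is not rendered, here is a minimal sketch of the new ``get_parameters``/``set_parameters`` API it introduces, based on the behaviour exercised by the tests further below. The algorithm and environment choice (PPO on ``CartPole-v1``) is illustrative only:

```python
from copy import deepcopy

import torch as th

from stable_baselines3 import PPO

model = PPO("MlpPolicy", "CartPole-v1")

# get_parameters returns a dict mapping object names
# (e.g. "policy", "policy.optimizer") to their state_dicts.
params = deepcopy(model.get_parameters())

# Perturb only the policy weights; the "policy.optimizer" entry is omitted.
random_params = {
    "policy": {name: th.rand_like(p) for name, p in params["policy"].items()}
}

# exact_match=False tolerates the missing "policy.optimizer" entry;
# with exact_match=True the same call would raise a ValueError.
model.set_parameters(random_params, exact_match=False)
```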

5 changes: 1 addition & 4 deletions stable_baselines3/common/on_policy_algorithm.py
@@ -240,10 +240,7 @@ def learn(
 
         return self
 
-    def get_torch_variables(self) -> Tuple[List[str], List[str]]:
-        """
-        cf base class
-        """
+    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
         state_dicts = ["policy", "policy.optimizer"]
 
         return state_dicts, []
72 changes: 34 additions & 38 deletions stable_baselines3/common/save_util.py
@@ -16,6 +16,7 @@
 import cloudpickle
 import torch as th
 
+import stable_baselines3
 from stable_baselines3.common.type_aliases import TensorDict
 from stable_baselines3.common.utils import get_device
 
@@ -284,21 +285,20 @@ def save_to_zip_file(
     save_path: Union[str, pathlib.Path, io.BufferedIOBase],
     data: Dict[str, Any] = None,
     params: Dict[str, Any] = None,
-    tensors: Dict[str, Any] = None,
+    pytorch_variables: Dict[str, Any] = None,
     verbose=0,
 ) -> None:
     """
-    Save a model to a zip archive.
+    Save model data to a zip archive.
     :param save_path: (Union[str, pathlib.Path, io.BufferedIOBase]) Where to store the model.
         if save_path is a str or pathlib.Path ensures that the path actually exists.
-    :param data: Class parameters being stored.
+    :param data: Class parameters being stored (non-PyTorch variables)
     :param params: Model parameters being stored expected to contain an entry for every
         state_dict with its name and the state_dict.
-    :param tensors: Extra tensor variables expected to contain name and value of tensors
+    :param pytorch_variables: Other PyTorch variables expected to contain name and value of the variable.
     :param verbose: (int) Verbosity level, 0 means only warnings, 2 means debug information
     """
 
     save_path = open_path(save_path, "w", verbose=0, suffix="zip")
     # data/params can be None, so do not
     # try to serialize them blindly
@@ -310,13 +310,15 @@ def save_to_zip_file(
         # Do not try to save "None" elements
         if data is not None:
             archive.writestr("data", serialized_data)
-        if tensors is not None:
-            with archive.open("tensors.pth", mode="w") as tensors_file:
-                th.save(tensors, tensors_file)
+        if pytorch_variables is not None:
+            with archive.open("pytorch_variables.pth", mode="w") as pytorch_variables_file:
+                th.save(pytorch_variables, pytorch_variables_file)
         if params is not None:
             for file_name, dict_ in params.items():
                 with archive.open(file_name + ".pth", mode="w") as param_file:
                     th.save(dict_, param_file)
+        # Save metadata: library version when file was saved
+        archive.writestr("_stable_baselines3_version", stable_baselines3.__version__)
 
 
 def save_to_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], obj, verbose=0) -> None:
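
For reference, a hedged sketch of calling the saving helper directly with the new argument names (normally ``model.save()`` does this internally; the path and dict contents are illustrative):

```python
from stable_baselines3 import PPO
from stable_baselines3.common.save_util import save_to_zip_file

model = PPO("MlpPolicy", "CartPole-v1")
save_to_zip_file(
    "demo_model",  # open_path appends the ".zip" suffix if missing
    data={"gamma": model.gamma},  # non-PyTorch class parameters, serialized via JSON/pickle
    params={"policy": model.policy.state_dict()},  # one entry per state_dict
    pytorch_variables=None,  # e.g. SAC would pass {"log_ent_coef": ...}
)
# Resulting archive entries: "data", "policy.pth", "_stable_baselines3_version"
```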
@@ -362,8 +364,8 @@ def load_from_zip_file(
     :param load_data: Whether we should load and return data
         (class parameters). Mainly used by 'load_parameters' to only load model parameters (weights)
     :param device: (Union[th.device, str]) Device on which the code should run.
-    :return: (dict),(dict),(dict) Class parameters, model state_dicts (dict of state_dict)
-        and dict of extra tensors
+    :return: (dict),(dict),(dict) Class parameters, model state_dicts (aka "params", dict of state_dict)
+        and dict of pytorch variables
     """
     load_path = open_path(load_path, "r", verbose=verbose, suffix="zip")
 
@@ -378,44 +380,38 @@ def load_from_zip_file(
             # zip archive, assume they were stored
             # as None (_save_to_file_zip allows this).
             data = None
-            tensors = None
+            pytorch_variables = None
             params = {}
 
             if "data" in namelist and load_data:
-                # Load class parameters and convert to string
+                # Load class parameters that are stored
+                # with either JSON or pickle (not PyTorch variables).
                 json_data = archive.read("data").decode()
                 data = json_to_data(json_data)
 
-            if "tensors.pth" in namelist and load_data:
-                # Load extra tensors
-                with archive.open("tensors.pth", mode="r") as tensor_file:
-                    # File has to be seekable, but opt_param_file is not, so load in BytesIO first
+            # Check for all .pth files and load them using th.load.
+            # "pytorch_variables.pth" stores PyTorch variables, and any other .pth
+            # files store state_dicts of variables with custom names (e.g. policy, policy.optimizer)
+            pth_files = [file_name for file_name in namelist if os.path.splitext(file_name)[1] == ".pth"]
+            for file_path in pth_files:
+                with archive.open(file_path, mode="r") as param_file:
+                    # File has to be seekable, but param_file is not, so load in BytesIO first
                     # fixed in python >= 3.7
                     file_content = io.BytesIO()
-                    file_content.write(tensor_file.read())
+                    file_content.write(param_file.read())
                     # go to start of file
                     file_content.seek(0)
-                    # load the parameters with the right ``map_location``
-                    tensors = th.load(file_content, map_location=device)
-
-            # check for all other .pth files
-            other_files = [
-                file_name for file_name in namelist if os.path.splitext(file_name)[1] == ".pth" and file_name != "tensors.pth"
-            ]
-            # if there are any other files which end with .pth and aren't "params.pth"
-            # assume that they each are optimizer parameters
-            if len(other_files) > 0:
-                for file_path in other_files:
-                    with archive.open(file_path, mode="r") as opt_param_file:
-                        # File has to be seekable, but opt_param_file is not, so load in BytesIO first
-                        # fixed in python >= 3.7
-                        file_content = io.BytesIO()
-                        file_content.write(opt_param_file.read())
-                        # go to start of file
-                        file_content.seek(0)
-                        # load the parameters with the right ``map_location``
-                        params[os.path.splitext(file_path)[0]] = th.load(file_content, map_location=device)
+                    # Load the parameters with the right ``map_location``.
+                    # Remove ".pth" ending with splitext
+                    th_object = th.load(file_content, map_location=device)
+                    if file_path == "pytorch_variables.pth":
+                        # PyTorch variables (not state_dicts)
+                        pytorch_variables = th_object
+                    else:
+                        # State dicts. Store into params dictionary
+                        # with same name as in .zip file (without .pth)
+                        params[os.path.splitext(file_path)[0]] = th_object
     except zipfile.BadZipFile:
         # load_path wasn't a zip file
         raise ValueError(f"Error: the file {load_path} wasn't a zip-file")
-    return data, params, tensors
+    return data, params, pytorch_variables
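
And the corresponding loader contract, sketched under the assumption that a model was previously saved with ``model.save`` (the path name is illustrative):

```python
from stable_baselines3 import PPO
from stable_baselines3.common.save_util import load_from_zip_file

PPO("MlpPolicy", "CartPole-v1").save("demo_model")

# data: class parameters, params: dict of state_dicts keyed by object name,
# pytorch_variables: extra PyTorch variables (e.g. SAC's log_ent_coef).
data, params, pytorch_variables = load_from_zip_file("demo_model", device="cpu")

# For PPO, params holds the "policy" and "policy.optimizer" state_dicts
assert "policy" in params and "policy.optimizer" in params
print(data["gamma"])  # class parameters round-trip through JSON
```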
16 changes: 3 additions & 13 deletions stable_baselines3/dqn/dqn.py
@@ -231,20 +231,10 @@ def learn(
             reset_num_timesteps=reset_num_timesteps,
         )
 
-    def excluded_save_params(self) -> List[str]:
-        """
-        Returns the names of the parameters that should be excluded by default
-        when saving the model.
-
-        :return: (List[str]) List of parameters that should be excluded from save
-        """
-        # Exclude aliases
-        return super(DQN, self).excluded_save_params() + ["q_net", "q_net_target"]
+    def _excluded_save_params(self) -> List[str]:
+        return super(DQN, self)._excluded_save_params() + ["q_net", "q_net_target"]
 
-    def get_torch_variables(self) -> Tuple[List[str], List[str]]:
-        """
-        cf base class
-        """
+    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
         state_dicts = ["policy", "policy.optimizer"]
 
         return state_dicts, []
24 changes: 7 additions & 17 deletions stable_baselines3/sac/sac.py
@@ -293,24 +293,14 @@ def learn(
             reset_num_timesteps=reset_num_timesteps,
         )
 
-    def excluded_save_params(self) -> List[str]:
-        """
-        Returns the names of the parameters that should be excluded by default
-        when saving the model.
-
-        :return: (List[str]) List of parameters that should be excluded from save
-        """
-        # Exclude aliases
-        return super(SAC, self).excluded_save_params() + ["actor", "critic", "critic_target"]
-
-    def get_torch_variables(self) -> Tuple[List[str], List[str]]:
-        """
-        cf base class
-        """
+    def _excluded_save_params(self) -> List[str]:
+        return super(SAC, self)._excluded_save_params() + ["actor", "critic", "critic_target"]
+
+    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
         state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
-        saved_tensors = ["log_ent_coef"]
+        saved_pytorch_variables = ["log_ent_coef"]
         if self.ent_coef_optimizer is not None:
             state_dicts.append("ent_coef_optimizer")
         else:
-            saved_tensors.append("ent_coef_tensor")
-        return state_dicts, saved_tensors
+            saved_pytorch_variables.append("ent_coef_tensor")
+        return state_dicts, saved_pytorch_variables
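
The two lists returned by ``_get_torch_save_params`` determine the archive layout: each name in the first list is saved through its ``state_dict()`` as its own ``.pth`` entry, while names in the second list are collected as plain variables into ``pytorch_variables.pth`` (as SAC does with ``log_ent_coef`` above). A hypothetical sketch of a custom algorithm hooking into this, where ``MyAlgorithm`` and ``my_log_temperature`` are made-up names:

```python
from typing import List, Tuple

from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm


class MyAlgorithm(OffPolicyAlgorithm):  # hypothetical subclass
    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
        # Saved via their state_dict(): modules and optimizers
        state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
        # Saved as-is into pytorch_variables.pth: bare tensors/parameters
        saved_pytorch_variables = ["my_log_temperature"]
        return state_dicts, saved_pytorch_variables
```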
18 changes: 4 additions & 14 deletions stable_baselines3/td3/td3.py
@@ -205,19 +205,9 @@ def learn(
             reset_num_timesteps=reset_num_timesteps,
         )
 
-    def excluded_save_params(self) -> List[str]:
-        """
-        Returns the names of the parameters that should be excluded by default
-        when saving the model.
-
-        :return: (List[str]) List of parameters that should be excluded from save
-        """
-        # Exclude aliases
-        return super(TD3, self).excluded_save_params() + ["actor", "critic", "actor_target", "critic_target"]
-
-    def get_torch_variables(self) -> Tuple[List[str], List[str]]:
-        """
-        cf base class
-        """
+    def _excluded_save_params(self) -> List[str]:
+        return super(TD3, self)._excluded_save_params() + ["actor", "critic", "actor_target", "critic_target"]
+
+    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
         state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
         return state_dicts, []
93 changes: 76 additions & 17 deletions tests/test_save_load.py
@@ -2,6 +2,7 @@
 import os
 import pathlib
 import warnings
+from collections import OrderedDict
 from copy import deepcopy
 
 import gym
@@ -33,7 +34,7 @@ def select_env(model_class: BaseAlgorithm) -> gym.Env:
 def test_save_load(tmp_path, model_class):
     """
     Test if 'save' and 'load' saves and loads model correctly
-    and if 'load_parameters' and 'get_policy_parameters' work correctly
+    and if 'get_parameters' and 'set_parameters' work correctly.
 
     ''warning does not test function of optimizer parameter load
@@ -49,19 +50,73 @@ def test_save_load(tmp_path, model_class):
     env.reset()
     observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)
 
-    # Get dictionary of current parameters
-    params = deepcopy(model.policy.state_dict())
+    # Get parameters of different objects
+    # deepcopy to avoid references to tensors we are about to modify
+    original_params = deepcopy(model.get_parameters())
 
-    # Modify all parameters to be random values
-    random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items())
+    # Test different error cases of set_parameters.
+    # Test that invalid object names throw errors
+    invalid_object_params = deepcopy(original_params)
+    invalid_object_params["I_should_not_be_a_valid_object"] = "and_I_am_an_invalid_tensor"
+    with pytest.raises(ValueError):
+        model.set_parameters(invalid_object_params, exact_match=True)
+    with pytest.raises(ValueError):
+        model.set_parameters(invalid_object_params, exact_match=False)
 
-    # Update model parameters with the new random values
-    model.policy.load_state_dict(random_params)
+    # Test that exact_match catches when something was missed.
+    missing_object_params = dict((k, v) for k, v in list(original_params.items())[:-1])
+    with pytest.raises(ValueError):
+        model.set_parameters(missing_object_params, exact_match=True)
+
+    # Test that exact_match also catches when something inside
+    # a state-dict is missing.
+    missing_state_dict_tensor_params = {}
+    for object_name in original_params:
+        object_params = {}
+        missing_state_dict_tensor_params[object_name] = object_params
+        # Skip last item in state-dict
+        for k, v in list(original_params[object_name].items())[:-1]:
+            object_params[k] = v
+    with pytest.raises(RuntimeError):
+        # PyTorch load_state_dict throws RuntimeError if strict but
+        # invalid state-dict.
+        model.set_parameters(missing_state_dict_tensor_params, exact_match=True)
+
+    # Test that parameters do indeed change.
+    random_params = {}
+    for object_name, params in original_params.items():
+        # Do not randomize optimizer parameters (custom layout)
+        if "optim" in object_name:
+            random_params[object_name] = params
+        else:
+            # Again, skip the last item in state-dict
+            random_params[object_name] = OrderedDict(
+                (param_name, th.rand_like(param)) for param_name, param in list(params.items())[:-1]
+            )
 
-    new_params = model.policy.state_dict()
-    # Check that all params are different now
-    for k in params:
-        assert not th.allclose(params[k], new_params[k]), "Parameters did not change as expected."
+    # Update model parameters with the new random values
+    model.set_parameters(random_params, exact_match=False)
+
+    new_params = model.get_parameters()
+    # Check that all params except the final item in each state-dict are different.
+    for object_name in original_params:
+        # Skip optimizers (no valid comparison with just th.allclose)
+        if "optim" in object_name:
+            continue
+        # state-dicts use ordered dictionaries, so key order
+        # is guaranteed.
+        last_key = list(original_params[object_name].keys())[-1]
+        for k in original_params[object_name]:
+            if k == last_key:
+                # Should be the same as before
+                assert th.allclose(
+                    original_params[object_name][k], new_params[object_name][k]
+                ), "Parameter changed despite not being included in the loaded parameters."
+            else:
+                # Should be different
+                assert not th.allclose(
+                    original_params[object_name][k], new_params[object_name][k]
+                ), "Parameters did not change as expected."
 
     params = new_params
 
@@ -81,14 +136,18 @@ def test_save_load(tmp_path, model_class):
     assert model.policy.device.type == get_device(device).type
 
     # check if params are still the same after load
-    new_params = model.policy.state_dict()
+    new_params = model.get_parameters()
 
     # Check that all params are the same as before save load procedure now
-    for key in params:
-        assert new_params[key].device.type == get_device(device).type
-        assert th.allclose(
-            params[key].to("cpu"), new_params[key].to("cpu")
-        ), "Model parameters not the same after save and load."
+    for object_name in new_params:
+        # Skip optimizers (no valid comparison with just th.allclose)
+        if "optim" in object_name:
+            continue
+        for key in params[object_name]:
+            assert new_params[object_name][key].device.type == get_device(device).type
+            assert th.allclose(
+                params[object_name][key].to("cpu"), new_params[object_name][key].to("cpu")
+            ), "Model parameters not the same after save and load."
 
     # check if model still selects the same actions
     new_selected_actions, _ = model.predict(observations, deterministic=True)
