
Get/set parameters and review of saving and loading #138

Merged · 29 commits · Sep 24, 2020
1d5b328
Update comments and docstrings
Miffyli Aug 10, 2020
9021e80
Rename get_torch_variables to private and update docs
Miffyli Aug 16, 2020
4acfb7b
Clarify documentation on data, params and tensors
Miffyli Aug 16, 2020
a5fac22
Make excluded_save_params private and update docs
Miffyli Aug 16, 2020
80d7436
Update get_torch_variable_names to get_torch_save_params for description
Miffyli Aug 16, 2020
272d2fe
Simplify saving code and update docs on params vs tensors
Miffyli Aug 16, 2020
4c3cf07
Rename saved item tensors to pytorch_variables for clarity
Miffyli Aug 16, 2020
27c1b33
Reformat
araffin Aug 23, 2020
6d9fe59
Merge branch 'master' into review/save_load_params
araffin Aug 23, 2020
f704e53
Merge branch 'master' into review/save_load_params
araffin Aug 23, 2020
f481ca6
Merge branch 'master' into review/save_load_params
araffin Aug 24, 2020
be37b77
Merge branch 'master' into review/save_load_params
araffin Aug 27, 2020
ff95823
Merge branch 'master' into review/save_load_params
araffin Aug 29, 2020
d3d399e
Merge branch 'master' into review/save_load_params
araffin Sep 1, 2020
832b068
Merge branch 'master' into review/save_load_params
araffin Sep 15, 2020
d6233fa
Merge branch 'master' into review/save_load_params
araffin Sep 20, 2020
6443431
Fix a typo
Miffyli Sep 22, 2020
6c11724
Add get/set_parameters, update tests accordingly
Miffyli Sep 23, 2020
6e37ca5
Use f-strings for formatting
Miffyli Sep 23, 2020
4aba7dd
Fix load docstring
Miffyli Sep 23, 2020
a9b63e0
Reorganize functions in BaseClass
Miffyli Sep 23, 2020
27cf174
Update changelog
Miffyli Sep 23, 2020
9b6f979
Add library version to the stored models
Miffyli Sep 23, 2020
e82311d
Actually run isort this time
Miffyli Sep 23, 2020
cab0431
Merge branch 'master' into review/save_load_params
araffin Sep 23, 2020
07a6b9d
Fix flake8 complaints and also fix testing code
Miffyli Sep 23, 2020
ceee92f
Fix isort
Miffyli Sep 23, 2020
edbb13a
...and black
Miffyli Sep 23, 2020
3e2efff
Fix set_random_seed
araffin Sep 24, 2020
10 changes: 10 additions & 0 deletions docs/misc/changelog.rst
@@ -16,13 +16,15 @@ New Features:
- Added ``StopTrainingOnMaxEpisodes`` to callback collection (@xicocaio)
- Added ``device`` keyword argument to ``BaseAlgorithm.load()`` (@liorcohen5)
- Callbacks have access to rollout collection locals as in SB2. (@PartiallyTyped)
- Added ``get_parameters`` and ``set_parameters`` for accessing/setting parameters of the agent
- Added actor/critic loss logging for TD3. (@mloo3)

Bug Fixes:
^^^^^^^^^^
- Fixed a bug where the environment was reset twice when using ``evaluate_policy``
- Fix logging of ``clip_fraction`` in PPO (@diditforlulz273)
- Fixed a bug where cuda support was wrongly checked when passing the GPU index, e.g., ``device="cuda:0"`` (@liorcohen5)
- Fixed a bug when the random seed was not properly set on cuda when passing the GPU index

Deprecations:
^^^^^^^^^^^^^
@@ -33,6 +35,14 @@ Others:
- Fix type annotation of ``make_vec_env`` (@ManifoldFR)
- Removed ``AlreadySteppingError`` and ``NotSteppingError`` that were not used
- Fixed typos in SAC and TD3
- Rename ``BaseClass.get_torch_variables`` -> ``BaseClass._get_torch_save_params`` and
``BaseClass.excluded_save_params`` -> ``BaseClass._excluded_save_params``
- Reorganized functions for clarity in ``BaseClass`` (save/load functions close to each other, private
functions at top)
- Clarified docstrings on what is saved and loaded to/from files
- Renamed saved items ``tensors`` to ``pytorch_variables`` for clarity
- Simplified ``save_to_zip_file`` function by removing duplicate code
- Store library version along with the saved models

Documentation:
^^^^^^^^^^^^^^
472 changes: 285 additions & 187 deletions stable_baselines3/common/base_class.py

Large diffs are not rendered by default.
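Since the ``base_class.py`` diff is not rendered, the semantics of the new ``get_parameters``/``set_parameters`` pair can be sketched in plain Python. This is only a mimic of the behavior exercised by the updated tests below — the real methods operate on PyTorch state_dicts keyed by the names from ``_get_torch_save_params``, and all names here are illustrative:

```python
# Plain-Python sketch of the get_parameters/set_parameters semantics
# introduced in this PR. Dicts stand in for PyTorch state_dicts.

class ModelSketch:
    def __init__(self):
        # One entry per saved object, as named by _get_torch_save_params
        self._objects = {
            "policy": {"weight": 1.0, "bias": 0.0},
            "policy.optimizer": {"lr": 0.001},
        }

    def get_parameters(self):
        """Return a dict mapping object names to copies of their state_dicts."""
        return {name: dict(state) for name, state in self._objects.items()}

    def set_parameters(self, params, exact_match=True):
        """Load parameters. Unknown object names always raise ValueError;
        with exact_match=True, every known object must also be present."""
        unknown = set(params) - set(self._objects)
        if unknown:
            raise ValueError(f"Unexpected objects: {unknown}")
        missing = set(self._objects) - set(params)
        if exact_match and missing:
            raise ValueError(f"Missing objects: {missing}")
        for name, state in params.items():
            self._objects[name].update(state)


model = ModelSketch()
params = model.get_parameters()
params["policy"]["weight"] = 2.0
model.set_parameters(params)  # full, exact update
model.set_parameters({"policy": {"bias": 1.0}}, exact_match=False)  # partial
```

This mirrors the error cases covered in ``test_save_load.py``: an invalid object name fails regardless of ``exact_match``, while a missing object only fails when ``exact_match=True``.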

5 changes: 1 addition & 4 deletions stable_baselines3/common/on_policy_algorithm.py
@@ -240,10 +240,7 @@ def learn(

return self

def get_torch_variables(self) -> Tuple[List[str], List[str]]:
"""
cf base class
"""
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
state_dicts = ["policy", "policy.optimizer"]

return state_dicts, []
72 changes: 34 additions & 38 deletions stable_baselines3/common/save_util.py
@@ -16,6 +16,7 @@
import cloudpickle
import torch as th

import stable_baselines3
from stable_baselines3.common.type_aliases import TensorDict
from stable_baselines3.common.utils import get_device

@@ -284,21 +285,20 @@ def save_to_zip_file(
save_path: Union[str, pathlib.Path, io.BufferedIOBase],
data: Dict[str, Any] = None,
params: Dict[str, Any] = None,
tensors: Dict[str, Any] = None,
pytorch_variables: Dict[str, Any] = None,
verbose=0,
) -> None:
"""
Save a model to a zip archive.
Save model data to a zip archive.

:param save_path: (Union[str, pathlib.Path, io.BufferedIOBase]) Where to store the model.
if save_path is a str or pathlib.Path, ensures that the path actually exists.
:param data: Class parameters being stored.
:param data: Class parameters being stored (non-PyTorch variables)
:param params: Model parameters being stored, expected to contain an entry for every
state_dict with its name and the state_dict.
:param tensors: Extra tensor variables expected to contain name and value of tensors
:param pytorch_variables: Other PyTorch variables, expected to contain name and value of the variable.
:param verbose: (int) Verbosity level, 0 means only warnings, 2 means debug information
"""

save_path = open_path(save_path, "w", verbose=0, suffix="zip")
# data/params can be None, so do not
# try to serialize them blindly
@@ -310,13 +310,15 @@ def save_to_zip_file(
# Do not try to save "None" elements
if data is not None:
archive.writestr("data", serialized_data)
if tensors is not None:
with archive.open("tensors.pth", mode="w") as tensors_file:
th.save(tensors, tensors_file)
if pytorch_variables is not None:
with archive.open("pytorch_variables.pth", mode="w") as pytorch_variables_file:
th.save(pytorch_variables, pytorch_variables_file)
if params is not None:
for file_name, dict_ in params.items():
with archive.open(file_name + ".pth", mode="w") as param_file:
th.save(dict_, param_file)
# Save metadata: library version when file was saved
archive.writestr("_stable_baselines3_version", stable_baselines3.__version__)
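The archive layout this function produces can be illustrated with stdlib tools alone. In this sketch, ``pickle`` stands in for ``th.save`` and the version string is hard-coded — both are assumptions for illustration, not the real serialization:

```python
import io
import json
import pickle
import zipfile

# Build an archive with the same member names save_to_zip_file uses.
buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w") as archive:
    archive.writestr("data", json.dumps({"gamma": 0.99}))       # class params
    with archive.open("policy.pth", mode="w") as f:             # one state_dict
        pickle.dump({"weight": [1.0, 2.0]}, f)
    with archive.open("pytorch_variables.pth", mode="w") as f:  # other variables
        pickle.dump({"log_ent_coef": -1.0}, f)
    archive.writestr("_stable_baselines3_version", "0.9.0")     # metadata

buffer.seek(0)
with zipfile.ZipFile(buffer) as archive:
    names = sorted(archive.namelist())
```

Storing the library version as a plain zip member means it can be inspected with any zip tool, without deserializing the model itself.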


def save_to_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], obj, verbose=0) -> None:
@@ -362,8 +364,8 @@ def load_from_zip_file(
:param load_data: Whether we should load and return data
(class parameters). Mainly used by 'load_parameters' to only load model parameters (weights)
:param device: (Union[th.device, str]) Device on which the code should run.
:return: (dict),(dict),(dict) Class parameters, model state_dicts (dict of state_dict)
and dict of extra tensors
:return: (dict),(dict),(dict) Class parameters, model state_dicts (aka "params", dict of state_dict)
and dict of pytorch variables
"""
load_path = open_path(load_path, "r", verbose=verbose, suffix="zip")

@@ -378,44 +380,38 @@
# zip archive, assume they were stored
# as None (_save_to_file_zip allows this).
data = None
tensors = None
pytorch_variables = None
params = {}

if "data" in namelist and load_data:
# Load class parameters and convert to string
# Load class parameters that are stored
# with either JSON or pickle (not PyTorch variables).
json_data = archive.read("data").decode()
data = json_to_data(json_data)

if "tensors.pth" in namelist and load_data:
# Load extra tensors
with archive.open("tensors.pth", mode="r") as tensor_file:
# File has to be seekable, but opt_param_file is not, so load in BytesIO first
# Check for all .pth files and load them using th.load.
# "pytorch_variables.pth" stores PyTorch variables, and any other .pth
# files store state_dicts of variables with custom names (e.g. policy, policy.optimizer)
pth_files = [file_name for file_name in namelist if os.path.splitext(file_name)[1] == ".pth"]
for file_path in pth_files:
with archive.open(file_path, mode="r") as param_file:
# File has to be seekable, but param_file is not, so load in BytesIO first
# fixed in python >= 3.7
file_content = io.BytesIO()
file_content.write(tensor_file.read())
file_content.write(param_file.read())
# go to start of file
file_content.seek(0)
# load the parameters with the right ``map_location``
tensors = th.load(file_content, map_location=device)

# check for all other .pth files
other_files = [
file_name for file_name in namelist if os.path.splitext(file_name)[1] == ".pth" and file_name != "tensors.pth"
]
# if there are any other files which end with .pth and aren't "params.pth"
# assume that they each are optimizer parameters
if len(other_files) > 0:
for file_path in other_files:
with archive.open(file_path, mode="r") as opt_param_file:
# File has to be seekable, but opt_param_file is not, so load in BytesIO first
# fixed in python >= 3.7
file_content = io.BytesIO()
file_content.write(opt_param_file.read())
# go to start of file
file_content.seek(0)
# load the parameters with the right ``map_location``
params[os.path.splitext(file_path)[0]] = th.load(file_content, map_location=device)
# Load the parameters with the right ``map_location``.
# Remove ".pth" ending with splitext
th_object = th.load(file_content, map_location=device)
if file_path == "pytorch_variables.pth":
# PyTorch variables (not state_dicts)
pytorch_variables = th_object
else:
# State dicts. Store into params dictionary
# with same name as in .zip file (without .pth)
params[os.path.splitext(file_path)[0]] = th_object
except zipfile.BadZipFile:
# load_path wasn't a zip file
raise ValueError(f"Error: the file {load_path} wasn't a zip-file")
return data, params, tensors
return data, params, pytorch_variables
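The ``BytesIO`` round-trip in the loading loop exists because ``ZipFile.open`` returns a non-seekable stream on Python < 3.7, while ``th.load`` needs a seekable file. The pattern in isolation, with ``pickle`` standing in for ``th.load`` (an assumption for the sketch):

```python
import io
import pickle
import zipfile

# Write a zip with one pickled ".pth" member.
buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w") as archive:
    with archive.open("policy.pth", mode="w") as f:
        pickle.dump({"weight": 0.5}, f)

buffer.seek(0)
with zipfile.ZipFile(buffer) as archive:
    with archive.open("policy.pth", mode="r") as member:
        # Copy into a seekable buffer first, then rewind and deserialize,
        # mirroring the file_content.seek(0) dance in load_from_zip_file.
        file_content = io.BytesIO(member.read())
        file_content.seek(0)
        state_dict = pickle.load(file_content)
```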
16 changes: 3 additions & 13 deletions stable_baselines3/dqn/dqn.py
@@ -231,20 +231,10 @@ def learn(
reset_num_timesteps=reset_num_timesteps,
)

def excluded_save_params(self) -> List[str]:
"""
Returns the names of the parameters that should be excluded by default
when saving the model.

:return: (List[str]) List of parameters that should be excluded from save
"""
# Exclude aliases
return super(DQN, self).excluded_save_params() + ["q_net", "q_net_target"]
def _excluded_save_params(self) -> List[str]:
return super(DQN, self)._excluded_save_params() + ["q_net", "q_net_target"]

def get_torch_variables(self) -> Tuple[List[str], List[str]]:
"""
cf base class
"""
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
state_dicts = ["policy", "policy.optimizer"]

return state_dicts, []
24 changes: 7 additions & 17 deletions stable_baselines3/sac/sac.py
@@ -293,24 +293,14 @@ def learn(
reset_num_timesteps=reset_num_timesteps,
)

def excluded_save_params(self) -> List[str]:
"""
Returns the names of the parameters that should be excluded by default
when saving the model.

:return: (List[str]) List of parameters that should be excluded from save
"""
# Exclude aliases
return super(SAC, self).excluded_save_params() + ["actor", "critic", "critic_target"]

def get_torch_variables(self) -> Tuple[List[str], List[str]]:
"""
cf base class
"""
def _excluded_save_params(self) -> List[str]:
return super(SAC, self)._excluded_save_params() + ["actor", "critic", "critic_target"]

def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
saved_tensors = ["log_ent_coef"]
saved_pytorch_variables = ["log_ent_coef"]
if self.ent_coef_optimizer is not None:
state_dicts.append("ent_coef_optimizer")
else:
saved_tensors.append("ent_coef_tensor")
return state_dicts, saved_tensors
saved_pytorch_variables.append("ent_coef_tensor")
return state_dicts, saved_pytorch_variables
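The names returned by ``_get_torch_save_params`` are dotted attribute paths on the algorithm object, so resolving them is plain ``getattr`` chaining. A minimal sketch with dummy objects — the names mirror SAC's lists above, but the objects themselves are invented stand-ins:

```python
from functools import reduce
from types import SimpleNamespace


def resolve(obj, dotted_name):
    """Follow a dotted attribute path such as 'actor.optimizer'."""
    return reduce(getattr, dotted_name.split("."), obj)


# Dummy algorithm exposing the attributes SAC's lists refer to.
algo = SimpleNamespace(
    policy=SimpleNamespace(name="policy"),
    actor=SimpleNamespace(optimizer=SimpleNamespace(name="actor-opt")),
    critic=SimpleNamespace(optimizer=SimpleNamespace(name="critic-opt")),
    log_ent_coef=0.1,
)

state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
resolved = [resolve(algo, name) for name in state_dicts]
```

The first list names objects saved via their ``state_dict()``, while the second list (``saved_pytorch_variables``) names bare variables such as ``log_ent_coef`` that are saved as-is.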
18 changes: 4 additions & 14 deletions stable_baselines3/td3/td3.py
@@ -205,19 +205,9 @@ def learn(
reset_num_timesteps=reset_num_timesteps,
)

def excluded_save_params(self) -> List[str]:
"""
Returns the names of the parameters that should be excluded by default
when saving the model.

:return: (List[str]) List of parameters that should be excluded from save
"""
# Exclude aliases
return super(TD3, self).excluded_save_params() + ["actor", "critic", "actor_target", "critic_target"]

def get_torch_variables(self) -> Tuple[List[str], List[str]]:
"""
cf base class
"""
def _excluded_save_params(self) -> List[str]:
return super(TD3, self)._excluded_save_params() + ["actor", "critic", "actor_target", "critic_target"]

def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
return state_dicts, []
93 changes: 76 additions & 17 deletions tests/test_save_load.py
@@ -2,6 +2,7 @@
import os
import pathlib
import warnings
from collections import OrderedDict
from copy import deepcopy

import gym
@@ -33,7 +34,7 @@ def select_env(model_class: BaseAlgorithm) -> gym.Env:
def test_save_load(tmp_path, model_class):
"""
Test if 'save' and 'load' saves and loads model correctly
and if 'load_parameters' and 'get_policy_parameters' work correctly
and if 'get_parameters' and 'set_parameters' work correctly.

''warning does not test function of optimizer parameter load

@@ -49,19 +50,73 @@ def test_save_load(tmp_path, model_class):
env.reset()
observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)

# Get dictionary of current parameters
params = deepcopy(model.policy.state_dict())
# Get parameters of different objects
# deepcopy to avoid referencing to tensors we are about to modify
original_params = deepcopy(model.get_parameters())

# Modify all parameters to be random values
random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items())
# Test different error cases of set_parameters.
# Test that invalid object names throw errors
invalid_object_params = deepcopy(original_params)
invalid_object_params["I_should_not_be_a_valid_object"] = "and_I_am_an_invalid_tensor"
with pytest.raises(ValueError):
model.set_parameters(invalid_object_params, exact_match=True)
with pytest.raises(ValueError):
model.set_parameters(invalid_object_params, exact_match=False)

# Update model parameters with the new random values
model.policy.load_state_dict(random_params)
# Test that exact_match catches when something was missed.
missing_object_params = dict((k, v) for k, v in list(original_params.items())[:-1])
with pytest.raises(ValueError):
model.set_parameters(missing_object_params, exact_match=True)

# Test that exact_match catches when something inside state-dict
# is missing but we have exact_match.
missing_state_dict_tensor_params = {}
for object_name in original_params:
object_params = {}
missing_state_dict_tensor_params[object_name] = object_params
# Skip last item in state-dict
for k, v in list(original_params[object_name].items())[:-1]:
object_params[k] = v
with pytest.raises(RuntimeError):
# PyTorch load_state_dict throws RuntimeError if strict but
# invalid state-dict.
model.set_parameters(missing_state_dict_tensor_params, exact_match=True)

# Test that parameters do indeed change.
random_params = {}
for object_name, params in original_params.items():
# Do not randomize optimizer parameters (custom layout)
if "optim" in object_name:
random_params[object_name] = params
else:
# Again, skip the last item in state-dict
random_params[object_name] = OrderedDict(
(param_name, th.rand_like(param)) for param_name, param in list(params.items())[:-1]
)

new_params = model.policy.state_dict()
# Check that all params are different now
for k in params:
assert not th.allclose(params[k], new_params[k]), "Parameters did not change as expected."
# Update model parameters with the new random values
model.set_parameters(random_params, exact_match=False)

new_params = model.get_parameters()
# Check that all params except the final item in each state-dict are different.
for object_name in original_params:
# Skip optimizers (no valid comparison with just th.allclose)
if "optim" in object_name:
continue
# state-dicts use ordered dictionaries, so key order
# is guaranteed.
last_key = list(original_params[object_name].keys())[-1]
for k in original_params[object_name]:
if k == last_key:
# Should be same as before
assert th.allclose(
original_params[object_name][k], new_params[object_name][k]
), "Parameter changed despite not being included in the loaded parameters."
else:
# Should be different
assert not th.allclose(
original_params[object_name][k], new_params[object_name][k]
), "Parameters did not change as expected."

params = new_params

@@ -81,14 +136,18 @@ def test_save_load(tmp_path, model_class):
assert model.policy.device.type == get_device(device).type

# check if params are still the same after load
new_params = model.policy.state_dict()
new_params = model.get_parameters()

# Check that all params are the same as before save load procedure now
for key in params:
assert new_params[key].device.type == get_device(device).type
assert th.allclose(
params[key].to("cpu"), new_params[key].to("cpu")
), "Model parameters not the same after save and load."
for object_name in new_params:
# Skip optimizers (no valid comparison with just th.allclose)
if "optim" in object_name:
continue
for key in params[object_name]:
assert new_params[object_name][key].device.type == get_device(device).type
assert th.allclose(
params[object_name][key].to("cpu"), new_params[object_name][key].to("cpu")
), "Model parameters not the same after save and load."

# check if model still selects the same actions
new_selected_actions, _ = model.predict(observations, deterministic=True)
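The comparison pattern this test relies on — randomize every state-dict entry except the last, then check exactly which entries changed — can be sketched without PyTorch. Here ``math.isclose`` stands in for ``th.allclose`` and the keys are illustrative:

```python
import math
import random
from collections import OrderedDict

original = OrderedDict([("w1", 0.5), ("w2", -0.3), ("bias", 1.0)])

# Randomize all but the last entry, as in the test above.
randomized = OrderedDict(
    (k, random.uniform(10.0, 20.0)) for k in list(original)[:-1]
)
merged = OrderedDict(original)
merged.update(randomized)  # exact_match=False style partial update

last_key = list(original)[-1]
changed = [k for k in original if not math.isclose(original[k], merged[k])]
```

Because state-dicts are ordered dictionaries, "the last key" is well defined, which is what lets the test assert that exactly the skipped entry stayed untouched.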