Get/set parameters and review of saving and loading #138

Merged: 29 commits, merged Sep 24, 2020 (changes shown from the first 18 of 29 commits)

Commits:
1d5b328  Update comments and docstrings (Miffyli, Aug 10, 2020)
9021e80  Rename get_torch_variables to private and update docs (Miffyli, Aug 16, 2020)
4acfb7b  Clarify documentation on data, params and tensors (Miffyli, Aug 16, 2020)
a5fac22  Make excluded_save_params private and update docs (Miffyli, Aug 16, 2020)
80d7436  Update get_torch_variable_names to get_torch_save_params for description (Miffyli, Aug 16, 2020)
272d2fe  Simplify saving code and update docs on params vs tensors (Miffyli, Aug 16, 2020)
4c3cf07  Rename saved item tensors to pytorch_variables for clarity (Miffyli, Aug 16, 2020)
27c1b33  Reformat (araffin, Aug 23, 2020)
6d9fe59  Merge branch 'master' into review/save_load_params (araffin, Aug 23, 2020)
f704e53  Merge branch 'master' into review/save_load_params (araffin, Aug 23, 2020)
f481ca6  Merge branch 'master' into review/save_load_params (araffin, Aug 24, 2020)
be37b77  Merge branch 'master' into review/save_load_params (araffin, Aug 27, 2020)
ff95823  Merge branch 'master' into review/save_load_params (araffin, Aug 29, 2020)
d3d399e  Merge branch 'master' into review/save_load_params (araffin, Sep 1, 2020)
832b068  Merge branch 'master' into review/save_load_params (araffin, Sep 15, 2020)
d6233fa  Merge branch 'master' into review/save_load_params (araffin, Sep 20, 2020)
6443431  Fix a typo (Miffyli, Sep 22, 2020)
6c11724  Add get/set_parameters, update tests accordingly (Miffyli, Sep 23, 2020)
6e37ca5  Use f-strings for formatting (Miffyli, Sep 23, 2020)
4aba7dd  Fix load docstring (Miffyli, Sep 23, 2020)
a9b63e0  Reorganize functions in BaseClass (Miffyli, Sep 23, 2020)
27cf174  Update changelog (Miffyli, Sep 23, 2020)
9b6f979  Add library version to the stored models (Miffyli, Sep 23, 2020)
e82311d  Actually run isort this time (Miffyli, Sep 23, 2020)
cab0431  Merge branch 'master' into review/save_load_params (araffin, Sep 23, 2020)
07a6b9d  Fix flake8 complaints and also fix testing code (Miffyli, Sep 23, 2020)
ceee92f  Fix isort (Miffyli, Sep 23, 2020)
edbb13a  ...and black (Miffyli, Sep 23, 2020)
3e2efff  Fix set_random_seed (araffin, Sep 24, 2020)
233 changes: 174 additions & 59 deletions stable_baselines3/common/base_class.py
@@ -12,12 +12,22 @@
import torch as th

from stable_baselines3.common import logger, utils
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, EvalCallback
from stable_baselines3.common.callbacks import (
BaseCallback,
CallbackList,
ConvertCallback,
EvalCallback,
)
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.policies import BasePolicy, get_policy_from_name
from stable_baselines3.common.preprocessing import is_image_space
from stable_baselines3.common.save_util import load_from_zip_file, recursive_getattr, recursive_setattr, save_to_zip_file
from stable_baselines3.common.save_util import (
load_from_zip_file,
recursive_getattr,
recursive_setattr,
save_to_zip_file,
)
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
from stable_baselines3.common.utils import (
check_for_correct_spaces,
@@ -26,7 +36,13 @@
set_random_seed,
update_learning_rate,
)
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, VecNormalize, VecTransposeImage, unwrap_vec_normalize
from stable_baselines3.common.vec_env import (
DummyVecEnv,
VecEnv,
VecNormalize,
VecTransposeImage,
unwrap_vec_normalize,
)


def maybe_make_env(env: Union[GymEnv, str, None], monitor_wrapper: bool, verbose: int) -> Optional[GymEnv]:
@@ -221,6 +237,43 @@ def _update_learning_rate(self, optimizers: Union[List[th.optim.Optimizer], th.optim.Optimizer]) -> None:
for optimizer in optimizers:
update_learning_rate(optimizer, self.lr_schedule(self._current_progress_remaining))

def _excluded_save_params(self) -> List[str]:
"""
Returns the names of the parameters that should be excluded from being
saved by pickling. E.g. replay buffers are skipped by default
as they take up a lot of space. PyTorch variables should be excluded
with this so they can be stored with ``th.save``.

:return: (List[str]) List of parameters that should be excluded from being saved with pickle.
"""
return [
"policy",
"device",
"env",
"eval_env",
"replay_buffer",
"rollout_buffer",
"_vec_normalize_env",
]

def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
"""
Get the name of the torch variables that will be saved with
PyTorch ``th.save``, ``th.load`` and ``state_dicts`` instead of the default
pickling strategy. This is to handle device placement correctly.

Names can point to specific variables under classes, e.g.
"policy.optimizer" would point to the ``optimizer`` object of
``self.policy``.

:return: (Tuple[List[str], List[str]])
List of Torch variables whose state dicts to save (e.g. th.nn.Modules),
and list of other Torch variables to store with ``th.save``.
"""
state_dicts = ["policy"]

return state_dicts, []
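
To show how the two hooks fit together, here is a minimal sketch of a hypothetical subclass (the class name and the ``my_large_buffer`` attribute are invented for illustration): anything named by ``_excluded_save_params`` is dropped from the pickled data, while names returned by ``_get_torch_save_params`` are stored via state-dicts / ``th.save`` instead.

from typing import List, Tuple

from stable_baselines3.common.base_class import BaseAlgorithm


class MyAlgorithm(BaseAlgorithm):
    """Sketch only: a real subclass would also implement ``learn`` etc."""

    def _excluded_save_params(self) -> List[str]:
        # Keep the default exclusions and also skip a (hypothetical)
        # large buffer attribute that we do not want pickled
        return super()._excluded_save_params() + ["my_large_buffer"]

    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
        # "policy.optimizer" resolves to ``self.policy.optimizer``
        state_dicts = ["policy", "policy.optimizer"]
        # No extra torch tensors to store with ``th.save`` in this sketch
        return state_dicts, []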

def get_env(self) -> Optional[VecEnv]:
"""
Returns the current environment (can be None if not defined).
@@ -255,19 +308,6 @@ def set_env(self, env: GymEnv) -> None:
self.n_envs = env.num_envs
self.env = env

def get_torch_variables(self) -> Tuple[List[str], List[str]]:
"""
Get the name of the torch variables that will be saved.
``th.save`` and ``th.load`` will be used with the right device
instead of the default pickling strategy.

:return: (Tuple[List[str], List[str]])
name of the variables with state dicts to save, name of additional torch tensors,
"""
state_dicts = ["policy"]

return state_dicts, []

@abstractmethod
def learn(
self,
@@ -315,41 +355,116 @@ def predict(
"""
return self.policy.predict(observation, state, mask, deterministic)

def set_parameters(
self,
load_path_or_dict: Union[str, Dict[str, Dict]],
exact_match: bool = True,
device: Union[th.device, str] = "auto",
) -> None:
"""
Load parameters from a given zip-file or a nested dictionary containing parameters for
different modules (see ``get_parameters``).

:param load_path_or_dict: Location of the saved data (path or file-like, see ``save``), or a nested
dictionary containing nn.Module parameters used by the policy. The dictionary maps
object names to a state-dictionary returned by ``torch.nn.Module.state_dict()``.
:param exact_match: If True, the given parameters must include parameters for each
module and each of their parameters, otherwise an exception is raised. If set to
False, this can be used to update only specific parameters.
:param device: (Union[th.device, str]) Device on which the code should run.
"""
params = None
if isinstance(load_path_or_dict, dict):
params = load_path_or_dict
else:
_, params, _ = load_from_zip_file(load_path_or_dict, device=device)

# Keep track of which objects were updated.
# `_get_torch_save_params` returns [params, other_pytorch_variables].
# We are only interested in the former here.
objects_needing_update = set(self._get_torch_save_params()[0])
updated_objects = set()

for name in params:
attr = None
try:
attr = recursive_getattr(self, name)
except Exception:
# Which errors could recursive_getattr throw? At least KeyError,
# but possibly something else too (e.g. if a key is an int?).
# Catch anything for now.
raise ValueError("Key {} is an invalid object name.".format(name))

if isinstance(attr, th.optim.Optimizer):
# Optimizers do not support "strict" keyword...
# Seems like they will just replace the whole
# optimizer state with the given one.
# On top of this, optimizer state-dict
# seems to change (e.g. first ``optim.step()``),
# which makes comparing state dictionary keys
# invalid (there is also a nesting of dictionaries
# with lists with dictionaries with ...), adding to the
# mess.
#
# TL;DR: We might not be able to reliably say
# if given state-dict is missing keys.
#
# Solution: Just load the state-dict as is, and trust
# the user has provided a sensible state dictionary.
attr.load_state_dict(params[name])
else:
# Assume attr is th.nn.Module
attr.load_state_dict(params[name], strict=exact_match)
updated_objects.add(name)

if exact_match and updated_objects != objects_needing_update:
raise ValueError(
"Names of parameters do not match agents' parameters: expected {}, got {}".format(
objects_needing_update, updated_objects
)
)
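
As a usage sketch (the saved file name is hypothetical), ``set_parameters`` accepts either a path to a zip-file produced by ``save`` or the nested dictionary returned by ``get_parameters``:

from stable_baselines3 import A2C

model = A2C("MlpPolicy", "CartPole-v1")
# Load all parameters from a previously saved agent (zip-file)
model.set_parameters("a2c_cartpole", exact_match=True)

# Or copy the parameters of another model with the same architecture
other = A2C("MlpPolicy", "CartPole-v1")
model.set_parameters(other.get_parameters())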

@classmethod
def load(
cls, load_path: str, env: Optional[GymEnv] = None, device: Union[th.device, str] = "auto", **kwargs
cls,
path: Union[str, pathlib.Path, io.BufferedIOBase],
env: Optional[GymEnv] = None,
device: Union[th.device, str] = "auto",
**kwargs,
) -> "BaseAlgorithm":
"""
Load the model from a zip-file

:param load_path: the location of the saved data
:param path: (Union[str, pathlib.Path, io.BufferedIOBase]) path to the file (or a file-like)
to load the agent from
:param env: the new environment to run the loaded model on
(can be None if you only need prediction from a trained model) has priority over any saved environment
:param device: (Union[th.device, str]) Device on which the code should run.
:param kwargs: extra arguments to change the model when loading
"""
data, params, tensors = load_from_zip_file(load_path, device=device)
data, params, pytorch_variables = load_from_zip_file(path, device=device)

# Remove stored device information and replace with ours
if "policy_kwargs" in data:
for arg_to_remove in ["device"]:
if arg_to_remove in data["policy_kwargs"]:
del data["policy_kwargs"][arg_to_remove]
if "device" in "policy_kwargs":
del data["policy_kwargs"]["device"]

if "policy_kwargs" in kwargs and kwargs["policy_kwargs"] != data["policy_kwargs"]:
raise ValueError(
f"The specified policy kwargs do not equal the stored policy kwargs."
f"Stored kwargs: {data['policy_kwargs']}, specified kwargs: {kwargs['policy_kwargs']}"
)

# check if observation space and action space are part of the saved parameters
if "observation_space" not in data or "action_space" not in data:
raise KeyError("The observation_space and action_space were not given, can't verify new environments")
# check if given env is valid

if env is not None:
# Check if given env is valid
check_for_correct_spaces(env, data["observation_space"], data["action_space"])
# if no new env was given use stored env if possible
if env is None and "env" in data:
env = data["env"]
else:
# Use stored env, if one exists. If not, continue as is (can be used for predict)
if "env" in data:
env = data["env"]

# noinspection PyArgumentList
model = cls(
@@ -365,14 +480,12 @@ def load(
model._setup_model()

# put state_dicts back in place
for name in params:
attr = recursive_getattr(model, name)
attr.load_state_dict(params[name])
model.set_parameters(params, exact_match=True, device=device)
Review discussion:

Member: Why is exact_match hardcoded here?

Collaborator (Author): Using exact_match=False would mean some parameters were missing from the saved model file, which should not happen unless someone has modified the file. Spelling out the hardcoded argument like this signals that we want every parameter to be updated exactly as it was saved, with nothing missing.

Member: Ok, fair enough. What can also happen is that parameters are renamed between versions (it happened to me after refactoring the continuous critic: the parameter names were not the same).
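
To make the exact_match discussion concrete, here is a sketch of the partial-update case (assuming two freshly created A2C models):

from stable_baselines3 import A2C

model = A2C("MlpPolicy", "CartPole-v1")
other = A2C("MlpPolicy", "CartPole-v1")

# Take only the policy network's state-dict, leaving out "policy.optimizer"
policy_only = {"policy": model.get_parameters()["policy"]}

# exact_match=True would raise a ValueError here, since not every
# object returned by _get_torch_save_params() is covered
other.set_parameters(policy_only, exact_match=False)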


# put tensors back in place
if tensors is not None:
for name in tensors:
recursive_setattr(model, name, tensors[name])
# put other pytorch variables back in place
if pytorch_variables is not None:
for name in pytorch_variables:
recursive_setattr(model, name, pytorch_variables[name])

# Sample gSDE exploration matrix, so it uses the right device
# see issue #44
@@ -513,14 +626,20 @@ def _update_info_buffer(self, infos: List[Dict[str, Any]], dones: Optional[np.ndarray] = None) -> None:
if maybe_is_success is not None and dones[idx]:
self.ep_success_buffer.append(maybe_is_success)

def excluded_save_params(self) -> List[str]:
def get_parameters(self) -> Dict[str, Dict]:
"""
Returns the names of the parameters that should be excluded by default
when saving the model.
Return the parameters of the agent. This includes parameters from different networks, e.g.
critics (value functions) and policies (pi functions).

:return: ([str]) List of parameters that should be excluded from save
:return: (Dict[str, Dict]) Mapping from names of the objects to PyTorch state-dicts.
"""
return ["policy", "device", "env", "eval_env", "replay_buffer", "rollout_buffer", "_vec_normalize_env"]
state_dicts_names, _ = self._get_torch_save_params()
params = {}
for name in state_dicts_names:
attr = recursive_getattr(self, name)
# Retrieve state dict
params[name] = attr.state_dict()
return params
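
A short sketch of what ``get_parameters`` returns; for on-policy algorithms the keys are "policy" and "policy.optimizer" (see the override at the end of this diff):

from stable_baselines3 import PPO

model = PPO("MlpPolicy", "CartPole-v1")
params = model.get_parameters()
print(list(params.keys()))  # ['policy', 'policy.optimizer']
# Each value is an ordinary PyTorch state-dict
print(type(params["policy"]))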

def save(
self,
@@ -532,46 +651,42 @@ def save(
Save all the attributes of the object and the model parameters in a zip-file.

:param path: (Union[str, pathlib.Path, io.BufferedIOBase]) path to the file where the RL agent should be saved
:param exclude: name of parameters that should be excluded in addition to the default one
:param exclude: name of parameters that should be excluded in addition to the default ones
:param include: name of parameters that might be excluded but should be included anyway
"""
# copy parameter list so we don't mutate the original dict
# Copy parameter list so we don't mutate the original dict
data = self.__dict__.copy()

# Exclude is union of specified parameters (if any) and standard exclusions
if exclude is None:
exclude = []
exclude = set(exclude).union(self.excluded_save_params())
exclude = set(exclude).union(self._excluded_save_params())

# Do not exclude params if they are specifically included
if include is not None:
exclude = exclude.difference(include)

state_dicts_names, tensors_names = self.get_torch_variables()
# any params that are in the save vars must not be saved by data
torch_variables = state_dicts_names + tensors_names
for torch_var in torch_variables:
# we need to get only the name of the top most module as we'll remove that
state_dicts_names, torch_variable_names = self._get_torch_save_params()
all_pytorch_variables = state_dicts_names + torch_variable_names
for torch_var in all_pytorch_variables:
# We need to get only the name of the top most module as we'll remove that
var_name = torch_var.split(".")[0]
# Any params that are in the save vars must not be saved by data
exclude.add(var_name)

# Remove parameter entries of parameters which are to be excluded
for param_name in exclude:
data.pop(param_name, None)

# Build dict of tensor variables
tensors = None
if tensors_names is not None:
tensors = {}
for name in tensors_names:
# Build dict of torch variables
pytorch_variables = None
if torch_variable_names is not None:
pytorch_variables = {}
for name in torch_variable_names:
attr = recursive_getattr(self, name)
tensors[name] = attr
pytorch_variables[name] = attr

# Build dict of state_dicts
params_to_save = {}
for name in state_dicts_names:
attr = recursive_getattr(self, name)
# Retrieve state dict
params_to_save[name] = attr.state_dict()
params_to_save = self.get_parameters()

save_to_zip_file(path, data=data, params=params_to_save, tensors=tensors)
save_to_zip_file(path, data=data, params=params_to_save, pytorch_variables=pytorch_variables)
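
Putting ``save`` and ``load`` together, a round-trip sketch (file name and environment are illustrative):

from stable_baselines3 import SAC

model = SAC("MlpPolicy", "Pendulum-v0")
model.learn(total_timesteps=1000)
model.save("sac_pendulum")  # writes sac_pendulum.zip

# Torch variables are mapped to the requested device on load,
# regardless of the device the model was trained on
loaded = SAC.load("sac_pendulum", device="cpu")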
5 changes: 1 addition & 4 deletions stable_baselines3/common/on_policy_algorithm.py
@@ -240,10 +240,7 @@ def learn(

return self

def get_torch_variables(self) -> Tuple[List[str], List[str]]:
"""
cf base class
"""
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
state_dicts = ["policy", "policy.optimizer"]

return state_dicts, []
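
For comparison, an off-policy algorithm would override the same hook to also cover its extra optimizers and plain tensors. The following is an illustrative sketch (names modeled on SAC, not a verbatim copy of its implementation):

def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
    # State-dicts: the policy plus the separate actor/critic optimizers
    state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
    # Plain torch tensors (no state_dict) stored with ``th.save``,
    # e.g. a learned log entropy coefficient
    saved_pytorch_variables = ["log_ent_coef"]
    return state_dicts, saved_pytorch_variables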