Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Get/set parameters and review of saving and loading #138

Merged
merged 29 commits into from
Sep 24, 2020
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
1d5b328
Update comments and docstrings
Miffyli Aug 10, 2020
9021e80
Rename get_torch_variables to private and update docs
Miffyli Aug 16, 2020
4acfb7b
Clarify documentation on data, params and tensors
Miffyli Aug 16, 2020
a5fac22
Make excluded_save_params private and update docs
Miffyli Aug 16, 2020
80d7436
Update get_torch_variable_names to get_torch_save_params for description
Miffyli Aug 16, 2020
272d2fe
Simplify saving code and update docs on params vs tensors
Miffyli Aug 16, 2020
4c3cf07
Rename saved item tensors to pytorch_variables for clarity
Miffyli Aug 16, 2020
27c1b33
Reformat
araffin Aug 23, 2020
6d9fe59
Merge branch 'master' into review/save_load_params
araffin Aug 23, 2020
f704e53
Merge branch 'master' into review/save_load_params
araffin Aug 23, 2020
f481ca6
Merge branch 'master' into review/save_load_params
araffin Aug 24, 2020
be37b77
Merge branch 'master' into review/save_load_params
araffin Aug 27, 2020
ff95823
Merge branch 'master' into review/save_load_params
araffin Aug 29, 2020
d3d399e
Merge branch 'master' into review/save_load_params
araffin Sep 1, 2020
832b068
Merge branch 'master' into review/save_load_params
araffin Sep 15, 2020
d6233fa
Merge branch 'master' into review/save_load_params
araffin Sep 20, 2020
6443431
Fix a typo
Miffyli Sep 22, 2020
6c11724
Add get/set_parameters, update tests accordingly
Miffyli Sep 23, 2020
6e37ca5
Use f-strings for formatting
Miffyli Sep 23, 2020
4aba7dd
Fix load docstring
Miffyli Sep 23, 2020
a9b63e0
Reorganize functions in BaseClass
Miffyli Sep 23, 2020
27cf174
Update changelog
Miffyli Sep 23, 2020
9b6f979
Add library version to the stored models
Miffyli Sep 23, 2020
e82311d
Actually run isort this time
Miffyli Sep 23, 2020
cab0431
Merge branch 'master' into review/save_load_params
araffin Sep 23, 2020
07a6b9d
Fix flake8 complaints and also fix testing code
Miffyli Sep 23, 2020
ceee92f
Fix isort
Miffyli Sep 23, 2020
edbb13a
...and black
Miffyli Sep 23, 2020
3e2efff
Fix set_random_seed
araffin Sep 24, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 58 additions & 50 deletions stable_baselines3/common/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,35 @@ def _update_learning_rate(self, optimizers: Union[List[th.optim.Optimizer], th.o
for optimizer in optimizers:
update_learning_rate(optimizer, self.lr_schedule(self._current_progress_remaining))

def _excluded_save_params(self) -> List[str]:
"""
Returns the names of the parameters that should be excluded from being
saved by pickling. E.g. replay buffers are skipped by default
as they take up a lot of space. PyTorch variables should be excluded
with this so they can be stored with ``th.save``.

:return: (List[str]) List of parameters that should be excluded from being saved with pickle.
"""
return ["policy", "device", "env", "eval_env", "replay_buffer", "rollout_buffer", "_vec_normalize_env"]

def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
"""
Get the name of the torch variables that will be saved with
PyTorch ``th.save``, ``th.load`` and ``state_dicts`` instead of the default
pickling strategy. This is to handle device placement correctly.

Names can point to specific variables under classes, e.g.
"policy.optimizer" would point to ``optimizer`` object of ``self.policy``
if this object.

:return: (Tuple[List[str], List[str]])
List of Torch variables whose state dicts to save (e.g. th.nn.Modules),
and list of other Torch variables to store with ``th.save``.
"""
state_dicts = ["policy"]

return state_dicts, []

def get_env(self) -> Optional[VecEnv]:
"""
Returns the current environment (can be None if not defined).
Expand Down Expand Up @@ -255,19 +284,6 @@ def set_env(self, env: GymEnv) -> None:
self.n_envs = env.num_envs
self.env = env

def get_torch_variables(self) -> Tuple[List[str], List[str]]:
    """
    Name the torch variables that will be saved with ``th.save`` and
    ``th.load`` (restored on the right device) instead of the default
    pickling strategy.

    :return: (Tuple[List[str], List[str]])
        name of the variables with state dicts to save, name of additional torch tensors,
    """
    # Default: only the policy state dict, no extra tensors.
    return ["policy"], []

@abstractmethod
def learn(
self,
Expand Down Expand Up @@ -325,28 +341,29 @@ def load(cls, load_path: str, env: Optional[GymEnv] = None, **kwargs) -> "BaseAl
(can be None if you only need prediction from a trained model) has priority over any saved environment
:param kwargs: extra arguments to change the model when loading
"""
data, params, tensors = load_from_zip_file(load_path)
data, params, pytorch_variables = load_from_zip_file(load_path)

# Remove stored device information and replace with ours
if "policy_kwargs" in data:
for arg_to_remove in ["device"]:
if arg_to_remove in data["policy_kwargs"]:
del data["policy_kwargs"][arg_to_remove]
if "device" in "policy_kwargs":
del data["policy_kwargs"]["device"]

if "policy_kwargs" in kwargs and kwargs["policy_kwargs"] != data["policy_kwargs"]:
raise ValueError(
f"The specified policy kwargs do not equal the stored policy kwargs."
f"Stored kwargs: {data['policy_kwargs']}, specified kwargs: {kwargs['policy_kwargs']}"
)

# check if observation space and action space are part of the saved parameters
if "observation_space" not in data or "action_space" not in data:
raise KeyError("The observation_space and action_space were not given, can't verify new environments")
# check if given env is valid

if env is not None:
# Check if given env is valid
check_for_correct_spaces(env, data["observation_space"], data["action_space"])
# if no new env was given use stored env if possible
if env is None and "env" in data:
env = data["env"]
else:
# Use stored env, if one exists. If not, continue as is (can be used for predict)
if "env" in data:
env = data["env"]

# noinspection PyArgumentList
model = cls(
Expand All @@ -366,10 +383,10 @@ def load(cls, load_path: str, env: Optional[GymEnv] = None, **kwargs) -> "BaseAl
attr = recursive_getattr(model, name)
attr.load_state_dict(params[name])

# put tensors back in place
if tensors is not None:
for name in tensors:
recursive_setattr(model, name, tensors[name])
# put other pytorch variables back in place
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo?

if pytorch_variables is not None:
for name in pytorch_variables:
recursive_setattr(model, name, pytorch_variables[name])

# Sample gSDE exploration matrix, so it uses the right device
# see issue #44
Expand Down Expand Up @@ -510,15 +527,6 @@ def _update_info_buffer(self, infos: List[Dict[str, Any]], dones: Optional[np.nd
if maybe_is_success is not None and dones[idx]:
self.ep_success_buffer.append(maybe_is_success)

def excluded_save_params(self) -> List[str]:
    """
    Returns the names of the parameters that should be excluded by default
    when saving the model.

    :return: ([str]) List of parameters that should be excluded from save
    """
    default_exclusions = ["policy", "device", "env", "eval_env", "replay_buffer"]
    default_exclusions += ["rollout_buffer", "_vec_normalize_env"]
    return default_exclusions

def save(
self,
path: Union[str, pathlib.Path, io.BufferedIOBase],
Expand All @@ -529,40 +537,40 @@ def save(
Save all the attributes of the object and the model parameters in a zip-file.

:param path: (Union[str, pathlib.Path, io.BufferedIOBase]) path to the file where the rl agent should be saved
:param exclude: name of parameters that should be excluded in addition to the default one
:param exclude: name of parameters that should be excluded in addition to the default ones
:param include: name of parameters that might be excluded but should be included anyway
"""
# copy parameter list so we don't mutate the original dict
# Copy parameter list so we don't mutate the original dict
data = self.__dict__.copy()

# Exclude is union of specified parameters (if any) and standard exclusions
if exclude is None:
exclude = []
exclude = set(exclude).union(self.excluded_save_params())
exclude = set(exclude).union(self._excluded_save_params())

# Do not exclude params if they are specifically included
if include is not None:
exclude = exclude.difference(include)

state_dicts_names, tensors_names = self.get_torch_variables()
# any params that are in the save vars must not be saved by data
torch_variables = state_dicts_names + tensors_names
for torch_var in torch_variables:
# we need to get only the name of the top most module as we'll remove that
state_dicts_names, torch_variable_names = self._get_torch_save_params()
all_pytorch_variables = state_dicts_names + torch_variable_names
for torch_var in all_pytorch_variables:
# We need to get only the name of the top most module as we'll remove that
var_name = torch_var.split(".")[0]
# Any params that are in the save vars must not be saved by data
exclude.add(var_name)

# Remove parameter entries of parameters which are to be excluded
for param_name in exclude:
data.pop(param_name, None)

# Build dict of tensor variables
tensors = None
if tensors_names is not None:
tensors = {}
for name in tensors_names:
# Build dict of torch variables
pytorch_variables = None
if torch_variable_names is not None:
pytorch_variables = {}
for name in torch_variable_names:
attr = recursive_getattr(self, name)
tensors[name] = attr
pytorch_variables[name] = attr

# Build dict of state_dicts
params_to_save = {}
Expand All @@ -571,4 +579,4 @@ def save(
# Retrieve state dict
params_to_save[name] = attr.state_dict()

save_to_zip_file(path, data=data, params=params_to_save, tensors=tensors)
save_to_zip_file(path, data=data, params=params_to_save, pytorch_variables=pytorch_variables)
5 changes: 1 addition & 4 deletions stable_baselines3/common/on_policy_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,10 +239,7 @@ def learn(

return self

def get_torch_variables(self) -> Tuple[List[str], List[str]]:
"""
cf base class
"""
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
state_dicts = ["policy", "policy.optimizer"]

return state_dicts, []
69 changes: 31 additions & 38 deletions stable_baselines3/common/save_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,21 +284,20 @@ def save_to_zip_file(
save_path: Union[str, pathlib.Path, io.BufferedIOBase],
data: Dict[str, Any] = None,
params: Dict[str, Any] = None,
tensors: Dict[str, Any] = None,
pytorch_variables: Dict[str, Any] = None,
verbose=0,
) -> None:
"""
Save a model to a zip archive.
Save model data to a zip archive.

:param save_path: (Union[str, pathlib.Path, io.BufferedIOBase]) Where to store the model.
if save_path is a str or pathlib.Path ensures that the path actually exists.
:param data: Class parameters being stored.
:param data: Class parameters being stored (non-PyTorch variables)
:param params: Model parameters being stored expected to contain an entry for every
state_dict with its name and the state_dict.
:param tensors: Extra tensor variables expected to contain name and value of tensors
:param pytorch_variables: Other PyTorch variables expected to contain name and value of the variable.
:param verbose: (int) Verbosity level, 0 means only warnings, 2 means debug information
"""

save_path = open_path(save_path, "w", verbose=0, suffix="zip")
# data/params can be None, so do not
# try to serialize them blindly
Expand All @@ -310,9 +309,9 @@ def save_to_zip_file(
# Do not try to save "None" elements
if data is not None:
archive.writestr("data", serialized_data)
if tensors is not None:
with archive.open("tensors.pth", mode="w") as tensors_file:
th.save(tensors, tensors_file)
if pytorch_variables is not None:
with archive.open("pytorch_variables.pth", mode="w") as pytorch_variables_file:
th.save(pytorch_variables, pytorch_variables_file)
if params is not None:
for file_name, dict_ in params.items():
with archive.open(file_name + ".pth", mode="w") as param_file:
Expand Down Expand Up @@ -358,8 +357,8 @@ def load_from_zip_file(
:param load_path: (str, pathlib.Path, io.BufferedIOBase) Where to load the model from
:param load_data: Whether we should load and return data
(class parameters). Mainly used by 'load_parameters' to only load model parameters (weights)
:return: (dict),(dict),(dict) Class parameters, model state_dicts (dict of state_dict)
and dict of extra tensors
:return: (dict),(dict),(dict) Class parameters, model state_dicts (aka "params", dict of state_dict)
and dict of pytorch variables
"""
load_path = open_path(load_path, "r", verbose=verbose, suffix="zip")

Expand All @@ -374,44 +373,38 @@ def load_from_zip_file(
# zip archive, assume they were stored
# as None (_save_to_file_zip allows this).
data = None
tensors = None
pytorch_variables = None
params = {}

if "data" in namelist and load_data:
# Load class parameters and convert to string
# Load class parameters that are stored
# with either JSON or pickle (not PyTorch variables).
json_data = archive.read("data").decode()
data = json_to_data(json_data)

if "tensors.pth" in namelist and load_data:
# Load extra tensors
with archive.open("tensors.pth", mode="r") as tensor_file:
# File has to be seekable, but opt_param_file is not, so load in BytesIO first
# Check for all .pth files and load them using th.load.
# "pytorch_variables.pth" stores PyTorch variables, and any other .pth
# files store state_dicts of variables with custom names (e.g. policy, policy.optimizer)
pth_files = [file_name for file_name in namelist if os.path.splitext(file_name)[1] == ".pth"]
for file_path in pth_files:
with archive.open(file_path, mode="r") as param_file:
# File has to be seekable, but param_file is not, so load in BytesIO first
# fixed in python >= 3.7
file_content = io.BytesIO()
file_content.write(tensor_file.read())
file_content.write(param_file.read())
# go to start of file
file_content.seek(0)
# load the parameters with the right ``map_location``
tensors = th.load(file_content, map_location=device)

# check for all other .pth files
other_files = [
file_name for file_name in namelist if os.path.splitext(file_name)[1] == ".pth" and file_name != "tensors.pth"
]
# if there are any other files which end with .pth and aren't "params.pth"
# assume that they each are optimizer parameters
if len(other_files) > 0:
for file_path in other_files:
with archive.open(file_path, mode="r") as opt_param_file:
# File has to be seekable, but opt_param_file is not, so load in BytesIO first
# fixed in python >= 3.7
file_content = io.BytesIO()
file_content.write(opt_param_file.read())
# go to start of file
file_content.seek(0)
# load the parameters with the right ``map_location``
params[os.path.splitext(file_path)[0]] = th.load(file_content, map_location=device)
# Load the parameters with the right ``map_location``.
# Remove ".pth" ending with splitext
th_object = th.load(file_content, map_location=device)
if file_path == "pytorch_variables.pth":
# PyTorch variables (not state_dicts)
pytorch_variables = th_object
else:
# State dicts. Store into params dictionary
# with same name as in .zip file (without .pth)
params[os.path.splitext(file_path)[0]] = th_object
except zipfile.BadZipFile:
# load_path wasn't a zip file
raise ValueError(f"Error: the file {load_path} wasn't a zip-file")
return data, params, tensors
return data, params, pytorch_variables
16 changes: 3 additions & 13 deletions stable_baselines3/dqn/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,20 +231,10 @@ def learn(
reset_num_timesteps=reset_num_timesteps,
)

def excluded_save_params(self) -> List[str]:
    """
    Returns the names of the parameters that should be excluded by default
    when saving the model.

    :return: (List[str]) List of parameters that should be excluded from save
    """
    # "q_net" and "q_net_target" are aliases of modules inside the policy,
    # so the policy state dict already covers them.
    aliases = ["q_net", "q_net_target"]
    return super(DQN, self).excluded_save_params() + aliases
def _excluded_save_params(self) -> List[str]:
    # Exclude the "q_net"/"q_net_target" aliases: they point at modules whose
    # parameters are already saved through the policy state dict.
    aliases = ["q_net", "q_net_target"]
    return super(DQN, self)._excluded_save_params() + aliases

def get_torch_variables(self) -> Tuple[List[str], List[str]]:
"""
cf base class
"""
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
state_dicts = ["policy", "policy.optimizer"]

return state_dicts, []
24 changes: 7 additions & 17 deletions stable_baselines3/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,24 +293,14 @@ def learn(
reset_num_timesteps=reset_num_timesteps,
)

def excluded_save_params(self) -> List[str]:
    """
    Returns the names of the parameters that should be excluded by default
    when saving the model.

    :return: (List[str]) List of parameters that should be excluded from save
    """
    # actor/critic/critic_target alias modules owned by the policy, so the
    # policy state dict already contains their parameters.
    aliases = ["actor", "critic", "critic_target"]
    return super(SAC, self).excluded_save_params() + aliases

def get_torch_variables(self) -> Tuple[List[str], List[str]]:
"""
cf base class
"""
def _excluded_save_params(self) -> List[str]:
    # actor/critic/critic_target alias modules owned by the policy; their
    # parameters are already saved via the policy state dict.
    aliases = ["actor", "critic", "critic_target"]
    return super(SAC, self)._excluded_save_params() + aliases

def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
saved_tensors = ["log_ent_coef"]
saved_pytorch_variables = ["log_ent_coef"]
if self.ent_coef_optimizer is not None:
state_dicts.append("ent_coef_optimizer")
else:
saved_tensors.append("ent_coef_tensor")
return state_dicts, saved_tensors
saved_pytorch_variables.append("ent_coef_tensor")
return state_dicts, saved_pytorch_variables
18 changes: 4 additions & 14 deletions stable_baselines3/td3/td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,19 +198,9 @@ def learn(
reset_num_timesteps=reset_num_timesteps,
)

def excluded_save_params(self) -> List[str]:
    """
    Returns the names of the parameters that should be excluded by default
    when saving the model.

    :return: (List[str]) List of parameters that should be excluded from save
    """
    # These attributes alias modules that live inside the policy.
    aliases = ["actor", "critic", "actor_target", "critic_target"]
    return super(TD3, self).excluded_save_params() + aliases

def get_torch_variables(self) -> Tuple[List[str], List[str]]:
"""
cf base class
"""
def _excluded_save_params(self) -> List[str]:
    # actor/critic/actor_target/critic_target alias modules inside the
    # policy, whose parameters are already saved via the policy state dict.
    aliases = ["actor", "critic", "actor_target", "critic_target"]
    return super(TD3, self)._excluded_save_params() + aliases

def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
return state_dicts, []