Skip to content

Commit

Permalink
Allow to set a device when loading a model (#154)
Browse files Browse the repository at this point in the history
* Added a 'device' keyword argument to BaseAlgorithm.load().
Edited the save and load test to also test the load method with all possible devices.
Added the changes to the changelog

* improved the load test to ensure that the model loads to the correct device.

* improved the test: now the correctness is improved. If the get_device policy would change, it wouldn't break the test.

* Update tests/test_save_load.py

@araffin's suggestion during the PR process

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>

* Update tests/test_save_load.py

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>

* Bug fixes: when comparing devices, comparing only device type since get_device() doesn't provide device index.
Now the code loads all of the model parameters from the saved state dict straight into the required device. (fixed load_from_zip_file).

* PR fixes: bug fix - a non-related test failed when running on GPU. updated the assertion to consider only types of devices. Also corrected a related bug in 'get_device()' method.

* Update changelog.rst

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
  • Loading branch information
leor-c and araffin authored Sep 20, 2020
1 parent 583d4b8 commit f5104a5
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 18 deletions.
4 changes: 3 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@ New Features:
^^^^^^^^^^^^^
- Added ``unwrap_vec_wrapper()`` to ``common.vec_env`` to extract ``VecEnvWrapper`` if needed
- Added ``StopTrainingOnMaxEpisodes`` to callback collection (@xicocaio)
- Added ``device`` keyword argument to ``BaseAlgorithm.load()`` (@liorcohen5)
- Callbacks have access to rollout collection locals as in SB2. (@PartiallyTyped)

Bug Fixes:
^^^^^^^^^^
- Fixed a bug where the environment was reset twice when using ``evaluate_policy``
- Fix logging of ``clip_fraction`` in PPO (@diditforlulz273)
- Fixed a bug where cuda support was wrongly checked when passing the GPU index, e.g., ``device="cuda:0"`` (@liorcohen5)

Deprecations:
^^^^^^^^^^^^^
Expand Down Expand Up @@ -399,4 +401,4 @@ And all the contributors:
@MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching
@flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3
@tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio
@diditforlulz273
@diditforlulz273 @liorcohen5
9 changes: 6 additions & 3 deletions stable_baselines3/common/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,16 +316,19 @@ def predict(
return self.policy.predict(observation, state, mask, deterministic)

@classmethod
def load(cls, load_path: str, env: Optional[GymEnv] = None, **kwargs) -> "BaseAlgorithm":
def load(
cls, load_path: str, env: Optional[GymEnv] = None, device: Union[th.device, str] = "auto", **kwargs
) -> "BaseAlgorithm":
"""
Load the model from a zip-file
:param load_path: the location of the saved data
:param env: the new environment to run the loaded model on
(can be None if you only need prediction from a trained model) has priority over any saved environment
:param device: (Union[th.device, str]) Device on which the code should run.
:param kwargs: extra arguments to change the model when loading
"""
data, params, tensors = load_from_zip_file(load_path)
data, params, tensors = load_from_zip_file(load_path, device=device)

if "policy_kwargs" in data:
for arg_to_remove in ["device"]:
Expand All @@ -352,7 +355,7 @@ def load(cls, load_path: str, env: Optional[GymEnv] = None, **kwargs) -> "BaseAl
model = cls(
policy=data["policy_class"],
env=env,
device="auto",
device=device,
_init_setup_model=False, # pytype: disable=not-instantiable,wrong-keyword-args
)

Expand Down
4 changes: 3 additions & 1 deletion stable_baselines3/common/save_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,7 @@ def load_from_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], verbose=0)
def load_from_zip_file(
load_path: Union[str, pathlib.Path, io.BufferedIOBase],
load_data: bool = True,
device: Union[th.device, str] = "auto",
verbose=0,
) -> (Tuple[Optional[Dict[str, Any]], Optional[TensorDict], Optional[TensorDict]]):
"""
Expand All @@ -360,13 +361,14 @@ def load_from_zip_file(
:param load_path: (str, pathlib.Path, io.BufferedIOBase) Where to load the model from
:param load_data: Whether we should load and return data
(class parameters). Mainly used by 'load_parameters' to only load model parameters (weights)
:param device: (Union[th.device, str]) Device on which the code should run.
:return: (dict),(dict),(dict) Class parameters, model state_dicts (dict of state_dict)
and dict of extra tensors
"""
load_path = open_path(load_path, "r", verbose=verbose, suffix="zip")

# set device to cpu if cuda is not available
device = get_device()
device = get_device(device=device)

# Open the zip archive and load data
try:
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def get_device(device: Union[th.device, str] = "auto") -> th.device:
device = th.device(device)

# Cuda not available
if device == th.device("cuda") and not th.cuda.is_available():
if device.type == th.device("cuda").type and not th.cuda.is_available():
return th.device("cpu")

return device
Expand Down
2 changes: 1 addition & 1 deletion tests/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_predict(model_class, env_id, device):
# Test detection of different shapes by the predict method
model = model_class("MlpPolicy", env_id, device=device)
# Check that the policy is on the right device
assert get_device(device) == model.policy.device
assert get_device(device).type == model.policy.device.type

env = gym.make(env_id)
vec_env = DummyVecEnv([lambda: gym.make(env_id), lambda: gym.make(env_id)])
Expand Down
35 changes: 24 additions & 11 deletions tests/test_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.identity_env import FakeImageEnv, IdentityEnv, IdentityEnvBox
from stable_baselines3.common.save_util import load_from_pkl, open_path, save_to_pkl
from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env import DummyVecEnv

MODEL_LIST = [PPO, A2C, TD3, SAC, DQN, DDPG]
Expand Down Expand Up @@ -70,21 +71,33 @@ def test_save_load(tmp_path, model_class):
# Check
model.save(tmp_path / "test_save.zip")
del model
model = model_class.load(str(tmp_path / "test_save.zip"), env=env)

# check if params are still the same after load
new_params = model.policy.state_dict()
# Check if the model loads as expected for every possible choice of device:
for device in ["auto", "cpu", "cuda"]:
model = model_class.load(str(tmp_path / "test_save.zip"), env=env, device=device)

# Check that all params are the same as before save load procedure now
for key in params:
assert th.allclose(params[key], new_params[key]), "Model parameters not the same after save and load."
# check if the model was loaded to the correct device
assert model.device.type == get_device(device).type
assert model.policy.device.type == get_device(device).type

# check if model still selects the same actions
new_selected_actions, _ = model.predict(observations, deterministic=True)
assert np.allclose(selected_actions, new_selected_actions, 1e-4)
# check if params are still the same after load
new_params = model.policy.state_dict()

# check if learn still works
model.learn(total_timesteps=1000, eval_freq=500)
# Check that all params are the same as before save load procedure now
for key in params:
assert new_params[key].device.type == get_device(device).type
assert th.allclose(
params[key].to("cpu"), new_params[key].to("cpu")
), "Model parameters not the same after save and load."

# check if model still selects the same actions
new_selected_actions, _ = model.predict(observations, deterministic=True)
assert np.allclose(selected_actions, new_selected_actions, 1e-4)

# check if learn still works
model.learn(total_timesteps=1000, eval_freq=500)

del model

# clear file from os
os.remove(tmp_path / "test_save.zip")
Expand Down

0 comments on commit f5104a5

Please sign in to comment.