Skip to content

Commit

Permalink
deepcopy model state_dict in tests (#2887)
Browse files Browse the repository at this point in the history
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
rohitgr7 and mergify[bot] committed Aug 8, 2020
1 parent 1bb268a commit 4d0406e
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 3 deletions.
3 changes: 2 additions & 1 deletion tests/callbacks/test_early_stopping.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from copy import deepcopy
import pickle

import cloudpickle
Expand All @@ -24,7 +25,7 @@ def __init__(self, *args, **kwargs):

def on_validation_end(self, trainer, pl_module):
super().on_validation_end(trainer, pl_module)
self.saved_states.append(self.state_dict().copy())
self.saved_states.append(deepcopy(self.state_dict()))

class EarlyStoppingTestRestore(EarlyStopping):
def __init__(self, expected_state):
Expand Down
3 changes: 2 additions & 1 deletion tests/trainer/test_lr_finder.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from copy import deepcopy
import pytest
import torch

Expand Down Expand Up @@ -33,7 +34,7 @@ def test_model_reset_correctly(tmpdir):
max_epochs=1,
)

before_state_dict = model.state_dict()
before_state_dict = deepcopy(model.state_dict())

_ = trainer.lr_find(model, num_training=5)

Expand Down
3 changes: 2 additions & 1 deletion tests/trainer/test_trainer_tricks.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from copy import deepcopy
import pytest
import torch
from torch.utils.data import RandomSampler, SequentialSampler, DataLoader
Expand Down Expand Up @@ -148,7 +149,7 @@ def test_model_reset_correctly(tmpdir):
max_epochs=1,
)

before_state_dict = model.state_dict()
before_state_dict = deepcopy(model.state_dict())

trainer.scale_batch_size(model, max_trials=5)

Expand Down

0 comments on commit 4d0406e

Please sign in to comment.