.fit() returns last not best weights in ddp_spawn #2565

Merged · 13 commits · Jul 9, 2020
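The change routes state from the spawned rank-0 worker back to the parent process through a multiprocessing queue: the worker puts the best checkpoint path, the results, and a path to the last weights, and the parent reads them back in the same order and restores the last weights into the in-memory model. The following is a minimal, self-contained sketch of that hand-off pattern, not the Lightning implementation; the worker body, result dict, and checkpoint paths are placeholders.

import os
import tempfile

import torch
import torch.multiprocessing as mp


def worker(process_idx, q, state_dict):
    # real training would happen here; these values stand in for the outcome
    results = {'val_loss': 0.123}          # placeholder results dict
    best_model_path = '/tmp/best.ckpt'     # would come from the checkpoint callback

    if process_idx == 0:
        # persist the *last* weights so the parent can restore them after spawn exits
        last_path = os.path.join(tempfile.gettempdir(), '__temp_weight_ddp_end.ckpt')
        torch.save(state_dict, last_path)

        # order matters: the parent calls q.get() in the same sequence
        q.put(best_model_path)
        q.put(results)
        q.put(last_path)


if __name__ == '__main__':
    model = torch.nn.Linear(4, 2)
    smp = mp.get_context('spawn')
    q = smp.SimpleQueue()
    mp.spawn(worker, args=(q, model.state_dict()), nprocs=2)

    # parent restores state in the same order it was queued
    best_path = q.get()
    results = q.get()
    last_path = q.get()

    # the parent-process model now holds the last weights from the worker
    model.load_state_dict(torch.load(last_path))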
28 changes: 25 additions & 3 deletions pytorch_lightning/trainer/distrib_data_parallel.py
@@ -189,6 +189,7 @@ class TrainerDDPMixin(ABC):
    num_nodes: int
    node_rank: int
    tpu_cores: int
    testing: bool

    @property
    @abstractmethod
@@ -555,15 +556,35 @@ def ddp_train(self, process_idx, q, model, is_master=False, proc_offset=0):
        # continue training routine
        results = self.run_pretrain_routine(model)

        # persist info in ddp_spawn
        self.__transfer_ddp_spawn_state_on_fit_end(model, q, results)

        # clean up memory
        torch.cuda.empty_cache()

        if self.global_rank == 0 and self.distributed_backend not in ['ddp_spawn', 'ddp_cpu']:
            return results

    def __transfer_ddp_spawn_state_on_fit_end(self, model, q, results):
        if self.distributed_backend not in ['ddp_spawn', 'ddp_cpu']:
            return

        # track the best model path
        best_model_path = None
        if self.checkpoint_callback is not None:
            best_model_path = self.checkpoint_callback.best_model_path

        if self.global_rank == 0 and q is not None:
            rank_zero_warn('cleaning up ddp environment...')
            q.put(best_model_path)
            q.put(results)

            # save the last weights
            last_path = None
            if not self.testing:
                last_path = os.path.join(self.default_root_dir, '__temp_weight_ddp_end.ckpt')
                torch.save(model.state_dict(), last_path)
            q.put(last_path)

    def save_spawn_weights(self, model):
        """
@@ -574,6 +595,7 @@ def save_spawn_weights(self, model):
        if self.is_global_zero:
            path = os.path.join(self.default_root_dir, '__temp_weight_ddp_end.ckpt')
            self.save_checkpoint(path)
            return path

    def load_spawn_weights(self, original_model):
        """
13 changes: 9 additions & 4 deletions pytorch_lightning/trainer/trainer.py
@@ -35,7 +35,7 @@
from pytorch_lightning.utilities import rank_zero_warn, parsing, rank_zero_info, rank_zero_only
import warnings

# warnings to ignore in trainer
warnings.filterwarnings('ignore', message='torch.distributed.reduce_op is deprecated, '
'please use torch.distributed.ReduceOp instead')

@@ -1063,9 +1063,14 @@ def __run_ddp_spawn(self, model, nprocs):
        # restore main state with best weights
        best_path = q.get()
        results = q.get()
        last_path = q.get()

        # transfer back the best path to the trainer
        self.checkpoint_callback.best_model_path = best_path

        # load last weights
        if last_path is not None and not self.testing:
            ckpt = torch.load(last_path, map_location=lambda storage, loc: storage)
            model.load_state_dict(ckpt)

        self.model = model
        return results
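For a user, the practical effect is that after fit() under ddp_spawn the in-memory model in the parent process carries the last weights, while the best checkpoint remains reachable via trainer.checkpoint_callback.best_model_path. Below is a rough usage sketch against the 0.8.x-era API; the LitModel module, data, and trainer arguments are illustrative and not part of this PR.

import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    """Minimal illustrative module, not part of the PR."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self(x), y)
        return {'loss': loss}

    def train_dataloader(self):
        x, y = torch.randn(64, 8), torch.randn(64, 1)
        return DataLoader(TensorDataset(x, y), batch_size=16)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == '__main__':
    model = LitModel()
    # ddp_cpu spawns worker processes just like ddp_spawn, but needs no GPUs
    trainer = pl.Trainer(distributed_backend='ddp_cpu', num_processes=2, max_epochs=1)
    trainer.fit(model)

    # the in-memory `model` now holds the last weights from the spawned workers;
    # the best checkpoint (if one was saved) is still addressable by path
    print(trainer.checkpoint_callback.best_model_path)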
21 changes: 21 additions & 0 deletions tests/models/test_test_loop.py
@@ -23,9 +23,16 @@ def test_single_gpu_test(tmpdir):
    results = trainer.test()
    assert 'test_acc' in results

    old_weights = model.c_d1.weight.clone().detach().cpu()

    results = trainer.test(model)
    assert 'test_acc' in results

    # make sure weights didn't change
    new_weights = model.c_d1.weight.clone().detach().cpu()

    assert torch.all(torch.eq(old_weights, new_weights))


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_dp_test(tmpdir):
@@ -45,9 +52,16 @@ def test_dp_test(tmpdir):
    results = trainer.test()
    assert 'test_acc' in results

    old_weights = model.c_d1.weight.clone().detach().cpu()

    results = trainer.test(model)
    assert 'test_acc' in results

    # make sure weights didn't change
    new_weights = model.c_d1.weight.clone().detach().cpu()

    assert torch.all(torch.eq(old_weights, new_weights))


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_ddp_spawn_test(tmpdir):
@@ -67,5 +81,12 @@ def test_ddp_spawn_test(tmpdir):
    results = trainer.test()
    assert 'test_acc' in results

    old_weights = model.c_d1.weight.clone().detach().cpu()

    results = trainer.test(model)
    assert 'test_acc' in results

    # make sure weights didn't change
    new_weights = model.c_d1.weight.clone().detach().cpu()

    assert torch.all(torch.eq(old_weights, new_weights))