.fit() returns last not best weights in ddp_spawn (#2565)
* added base tests for tpu

* enable none checkpoint
williamFalcon committed Jul 9, 2020
1 parent e1bc208 commit 4bbcfa0
Showing 3 changed files with 55 additions and 7 deletions.
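Context for the diffs below: with the ddp_spawn (and ddp_cpu) backends the training loop runs in a spawned child process, so anything the parent Trainer needs after .fit() — the best checkpoint path, the results, and now also the last weights — must be handed back through a multiprocessing queue. The following is a minimal sketch of that handoff pattern, not Lightning's actual API; worker, weights_dir, and the toy model are hypothetical stand-ins.

    import os
    import tempfile

    import torch
    import torch.multiprocessing as mp


    def worker(process_idx, q, weights_dir):
        # real training would happen here; a toy model stands in for it
        model = torch.nn.Linear(2, 2)
        results = {'test_acc': 1.0}

        if process_idx == 0:
            # may legitimately be None, e.g. when checkpointing is disabled
            best_model_path = None
            q.put(best_model_path)
            q.put(results)

            # persist the *last* weights so the parent can restore them
            last_path = os.path.join(weights_dir, '__temp_weight_ddp_end.ckpt')
            torch.save(model.state_dict(), last_path)
            q.put(last_path)


    if __name__ == '__main__':
        smp = mp.get_context('spawn')
        q = smp.SimpleQueue()
        with tempfile.TemporaryDirectory() as weights_dir:
            mp.spawn(worker, args=(q, weights_dir), nprocs=2)

            best_path = q.get()
            results = q.get()
            last_path = q.get()

            # rebuild the model in the parent and load the last weights
            model = torch.nn.Linear(2, 2)
            if last_path is not None:
                model.load_state_dict(torch.load(last_path, map_location='cpu'))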
28 changes: 25 additions & 3 deletions pytorch_lightning/trainer/distrib_data_parallel.py
@@ -189,6 +189,7 @@ class TrainerDDPMixin(ABC):
     num_nodes: int
     node_rank: int
     tpu_cores: int
+    testing: bool
 
     @property
     @abstractmethod
@@ -555,15 +556,35 @@ def ddp_train(self, process_idx, q, model, is_master=False, proc_offset=0):
         # continue training routine
         results = self.run_pretrain_routine(model)
 
+        # persist info in ddp_spawn
+        self.__transfer_ddp_spawn_state_on_fit_end(model, q, results)
+
         # clean up memory
         torch.cuda.empty_cache()
 
+        if self.global_rank == 0 and self.distributed_backend not in ['ddp_spawn', 'ddp_cpu']:
+            return results
+
+    def __transfer_ddp_spawn_state_on_fit_end(self, model, q, results):
+        if not self.distributed_backend in ['ddp_spawn', 'ddp_cpu']:
+            return
+
+        # track the best model path
+        best_model_path = None
+        if self.checkpoint_callback is not None:
+            best_model_path = self.checkpoint_callback.best_model_path
+
         if self.global_rank == 0 and q is not None:
-            q.put(self.checkpoint_callback.best_model_path)
+            rank_zero_warn('cleaning up ddp environment...')
+            q.put(best_model_path)
             q.put(results)
 
-        if self.global_rank == 0 and self.distributed_backend != 'ddp_spawn':
-            return results
+            # save the last weights
+            last_path = None
+            if not self.testing:
+                last_path = os.path.join(self.default_root_dir, '__temp_weight_ddp_end.ckpt')
+                torch.save(model.state_dict(), last_path)
+            q.put(last_path)
 
     def save_spawn_weights(self, model):
         """
@@ -574,6 +595,7 @@ def save_spawn_weights(self, model):
         if self.is_global_zero:
             path = os.path.join(self.default_root_dir, '__temp_weight_ddp_end.ckpt')
             self.save_checkpoint(path)
+            return path
 
     def load_spawn_weights(self, original_model):
         """
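A note on the guard added above: the repeated "enable none checkpoint" entries in the commit message point at the case Trainer(checkpoint_callback=False), where self.checkpoint_callback is None and the old unconditional q.put(self.checkpoint_callback.best_model_path) would raise AttributeError. A minimal sketch of the guard, with _Checkpoint as a hypothetical stand-in for ModelCheckpoint:

    class _Checkpoint:  # hypothetical stand-in for ModelCheckpoint
        best_model_path = 'epoch=3.ckpt'

    for checkpoint_callback in (_Checkpoint(), None):
        # guard before touching .best_model_path, as the commit does
        best_model_path = None
        if checkpoint_callback is not None:
            best_model_path = checkpoint_callback.best_model_path
        print(best_model_path)  # 'epoch=3.ckpt', then None instead of a crash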
13 changes: 9 additions & 4 deletions pytorch_lightning/trainer/trainer.py
@@ -35,7 +35,7 @@
 from pytorch_lightning.utilities import rank_zero_warn, parsing, rank_zero_info, rank_zero_only
 import warnings
 
-# warnings to ignore
+# warnings to ignore in trainer
 warnings.filterwarnings('ignore', message='torch.distributed.reduce_op is deprecated, '
                                           'please use torch.distributed.ReduceOp instead')

@@ -1063,9 +1063,14 @@ def __run_ddp_spawn(self, model, nprocs):
         # restore main state with best weights
         best_path = q.get()
         results = q.get()
-        if best_path is not None and len(best_path) > 0:
-            self.checkpoint_callback.best_model_path = best_path
-            model.load_from_checkpoint(best_path)
+        last_path = q.get()
+
+        # transfer back the best path to the trainer
+        self.checkpoint_callback.best_model_path = best_path
+
+        # load last weights
+        if last_path is not None and not self.testing:
+            torch.load(last_path, map_location=lambda storage, loc: storage)
 
         self.model = model
         return results
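One thing to note in the __run_ddp_spawn hunk above: the return value of torch.load(last_path, ...) is never assigned, so as written the checkpoint tensors are read and then discarded rather than copied into model. Restoring weights from a state dict normally takes an explicit load_state_dict call, sketched below with hypothetical stand-ins for the real model and path:

    import os
    import tempfile

    import torch

    # stand-ins for the real model and the worker-saved checkpoint
    model = torch.nn.Linear(4, 2)
    last_path = os.path.join(tempfile.gettempdir(), '__temp_weight_ddp_end.ckpt')
    torch.save(model.state_dict(), last_path)

    # map_location=lambda storage, loc: storage keeps tensors on CPU ...
    ckpt = torch.load(last_path, map_location=lambda storage, loc: storage)
    # ... and load_state_dict actually copies them into the model
    model.load_state_dict(ckpt)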
21 changes: 21 additions & 0 deletions tests/models/test_test_loop.py
@@ -23,9 +23,16 @@ def test_single_gpu_test(tmpdir):
     results = trainer.test()
     assert 'test_acc' in results
 
+    old_weights = model.c_d1.weight.clone().detach().cpu()
+
     results = trainer.test(model)
     assert 'test_acc' in results
 
+    # make sure weights didn't change
+    new_weights = model.c_d1.weight.clone().detach().cpu()
+
+    assert torch.all(torch.eq(old_weights, new_weights))
+
 
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 def test_dp_test(tmpdir):
@@ -45,9 +52,16 @@ def test_dp_test(tmpdir):
     results = trainer.test()
     assert 'test_acc' in results
 
+    old_weights = model.c_d1.weight.clone().detach().cpu()
+
     results = trainer.test(model)
     assert 'test_acc' in results
 
+    # make sure weights didn't change
+    new_weights = model.c_d1.weight.clone().detach().cpu()
+
+    assert torch.all(torch.eq(old_weights, new_weights))
+
 
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 def test_ddp_spawn_test(tmpdir):
@@ -67,5 +81,12 @@ def test_ddp_spawn_test(tmpdir):
     results = trainer.test()
     assert 'test_acc' in results
 
+    old_weights = model.c_d1.weight.clone().detach().cpu()
+
     results = trainer.test(model)
     assert 'test_acc' in results
+
+    # make sure weights didn't change
+    new_weights = model.c_d1.weight.clone().detach().cpu()
+
+    assert torch.all(torch.eq(old_weights, new_weights))

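All three tests share one regression pattern: snapshot a layer's weights before the second trainer.test(model) call, snapshot again afterwards, and assert elementwise equality. clone().detach().cpu() takes a CPU copy that later in-place updates or device moves cannot alter, so the comparison is meaningful. A self-contained sketch of the pattern, with torch.nn.Linear standing in for the model's c_d1 layer:

    import torch

    layer = torch.nn.Linear(4, 2)  # stand-in for the model's c_d1 layer

    old_weights = layer.weight.clone().detach().cpu()
    # trainer.test(model) would run here and must not mutate the weights
    new_weights = layer.weight.clone().detach().cpu()

    assert torch.all(torch.eq(old_weights, new_weights))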