.fit() returns last not best weights in ddp_spawn #2565

Merged Jul 9, 2020 (13 commits)

In ddp_spawn, the parent process restored the best checkpoint after .fit() but then reloaded the last spawned weights over it, so .fit() ended with the last rather than the best weights. This PR moves the child-to-parent state handoff into a dedicated __transfer_ddp_spawn_state_on_fit_end step, guards the best_model_path lookup when no checkpoint callback is set, and reloads the last spawned weights only when testing.
pytorch_lightning/trainer/distrib_data_parallel.py (23 changes: 20 additions & 3 deletions)

@@ -189,6 +189,7 @@ class TrainerDDPMixin(ABC):
     num_nodes: int
     node_rank: int
     tpu_cores: int
+    testing: bool
 
     @property
     @abstractmethod
@@ -555,15 +556,31 @@ def ddp_train(self, process_idx, q, model, is_master=False, proc_offset=0):
         # continue training routine
         results = self.run_pretrain_routine(model)
 
+        # persist info in ddp_spawn
+        self.__transfer_ddp_spawn_state_on_fit_end(model, q, results)
+
         # clean up memory
         torch.cuda.empty_cache()
 
+        if self.global_rank == 0 and self.distributed_backend != 'ddp_spawn':
+            return results
+
+    def __transfer_ddp_spawn_state_on_fit_end(self, model, q, results):
+        if not self.distributed_backend == 'ddp_spawn':
+            return
+
+        # track the best model path
+        best_model_path = None
+        if self.checkpoint_callback is not None:
+            best_model_path = self.checkpoint_callback.best_model_path
+
         if self.global_rank == 0 and q is not None:
-            q.put(self.checkpoint_callback.best_model_path)
+            rank_zero_warn('cleaning up ddp environment...')
+            q.put(best_model_path)
             q.put(results)
 
-        if self.global_rank == 0 and self.distributed_backend != 'ddp_spawn':
-            return results
+            if not self.testing:
+                self.save_spawn_weights(model)
 
     def save_spawn_weights(self, model):
         """
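The handoff above relies on a multiprocessing queue shared with the spawned workers: rank 0 puts the best checkpoint path first and the results second, so the parent has to q.get() them back in exactly that order. A minimal standalone sketch of that pattern (not the PR's code; the worker body and the paths are placeholders):

import torch.multiprocessing as mp

def worker(rank, q):
    # stand-ins for run_pretrain_routine() results and
    # checkpoint_callback.best_model_path
    results = {'val_loss': 0.25}
    best_model_path = '/tmp/epoch=3.ckpt'

    # only rank 0 reports back to the parent, mirroring ddp_train
    if rank == 0 and q is not None:
        q.put(best_model_path)
        q.put(results)

if __name__ == '__main__':
    smp = mp.get_context('spawn')
    q = smp.SimpleQueue()
    mp.spawn(worker, args=(q,), nprocs=2)

    # reads must mirror the put order: path first, results second
    best_path = q.get()
    results = q.get()
    print(best_path, results)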
pytorch_lightning/trainer/trainer.py (9 changes: 7 additions & 2 deletions)

@@ -35,7 +35,7 @@
 from pytorch_lightning.utilities import rank_zero_warn, parsing, rank_zero_info, rank_zero_only
 import warnings
 
-# warnings to ignore
+# warnings to ignore in trainer
 warnings.filterwarnings('ignore', message='torch.distributed.reduce_op is deprecated, '
                         'please use torch.distributed.ReduceOp instead')
 
@@ -1043,9 +1043,14 @@ def __run_ddp_spawn(self, model, nprocs):
         # restore main state with best weights
         best_path = q.get()
         results = q.get()
+
+        # transfer back the best path to the trainer
         if best_path is not None and len(best_path) > 0:
             self.checkpoint_callback.best_model_path = best_path
             model.load_from_checkpoint(best_path)
-        self.load_spawn_weights(model)
+
+        # load last model weights
+        if self.testing:
+            self.load_spawn_weights(model)
         self.model = model
         return results
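For a user, the effect of the fix is that a ddp_spawn run now hands back the best checkpoint instead of the last one. A hedged usage sketch against the 0.8-era API this PR targets (the LightningModule instance `model` is assumed to exist and to log 'val_loss'):

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

# track the best checkpoint by validation loss
checkpoint_callback = ModelCheckpoint(monitor='val_loss', mode='min')

trainer = pl.Trainer(
    gpus=2,
    distributed_backend='ddp_spawn',  # the backend this PR fixes
    checkpoint_callback=checkpoint_callback,
    max_epochs=10,
)
trainer.fit(model)  # `model`: assumed LightningModule, not defined here

# after fit(), the parent process has restored the best weights and
# best_model_path was transferred back from the spawned rank 0
print(checkpoint_callback.best_model_path)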
Expand Down