
.fit() returns last, not best, weights in ddp_spawn #2565

Merged · 13 commits · Jul 9, 2020
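For context, a minimal sketch of the behavior this PR fixes (BoringModel and the trainer arguments are illustrative placeholders, not code from this PR): under ddp_spawn, training runs in spawned subprocesses, so the parent process has to restore the model from a checkpoint after .fit() returns; before this fix it ended up holding the last weights rather than the best ones.

    import pytorch_lightning as pl

    model = BoringModel()  # stands in for any LightningModule
    trainer = pl.Trainer(gpus=2, distributed_backend='ddp_spawn', max_epochs=5)
    trainer.fit(model)

    # After this PR, the parent process receives the best checkpoint path
    # from the spawned rank-0 process and reloads those weights:
    print(trainer.checkpoint_callback.best_model_path)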
pytorch_lightning/trainer/distrib_data_parallel.py (5 changes: 5 additions and 0 deletions)

@@ -189,6 +189,7 @@ class TrainerDDPMixin(ABC):
    num_nodes: int
    node_rank: int
    tpu_cores: int
    testing: bool

    @property
    @abstractmethod
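The new testing: bool annotation follows the mixin pattern used throughout the trainer: TrainerDDPMixin's methods read attributes that only the concrete Trainer actually sets, and a class-level annotation declares them for type checkers. A stripped-down sketch of the pattern (everything except the testing attribute is illustrative):

    from abc import ABC

    class TrainerDDPMixin(ABC):
        # Declared, not assigned: the concrete Trainer is expected to set
        # this attribute before the mixin's methods run.
        testing: bool

        def run_sketch(self):
            # The mixin can now branch on self.testing without the type
            # checker flagging an unknown attribute.
            return 'skip save' if self.testing else 'save weights'

    class Trainer(TrainerDDPMixin):
        def __init__(self, testing: bool = False):
            self.testing = testing

    print(Trainer(testing=True).run_sketch())  # -> 'skip save'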
@@ -559,9 +560,13 @@ def ddp_train(self, process_idx, q, model, is_master=False, proc_offset=0):
        torch.cuda.empty_cache()

        if self.global_rank == 0 and q is not None:
            rank_zero_warn('cleaning up ddp environment...')
            q.put(self.checkpoint_callback.best_model_path)
Review comment (Member), on the q.put(self.checkpoint_callback.best_model_path) line:

Could we add a None check here for checkpoint_callback? The user can set it to None if they want. See #2547

Reply (Contributor, author):

fixed!
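The guard itself isn't visible in this snapshot of the diff; a plausible sketch of the fix the author describes, reusing the surrounding ddp_train names (the None sentinel is an assumption):

            # checkpoint_callback can be disabled (set to None) by the
            # user, so only read best_model_path when it exists.
            best_path = None
            if self.checkpoint_callback is not None:
                best_path = self.checkpoint_callback.best_model_path
            q.put(best_path)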

            q.put(results)

        if not self.testing:
            self.save_spawn_weights(model)

        if self.global_rank == 0 and self.distributed_backend != 'ddp_spawn':
            return results
pytorch_lightning/trainer/trainer.py (7 changes: 6 additions and 1 deletion)

@@ -1043,9 +1043,14 @@ def __run_ddp_spawn(self, model, nprocs):
        # restore main state with best weights
        best_path = q.get()
        results = q.get()

        # transfer back the best path to the trainer
        if best_path is not None and len(best_path) > 0:
            self.checkpoint_callback.best_model_path = best_path
            model.load_from_checkpoint(best_path)

        # load last model weights
        if self.testing:
            self.load_spawn_weights(model)

        self.model = model
        return results
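For readers unfamiliar with the handoff: the spawned rank-0 process and the parent communicate through a multiprocessing queue, and the get() order in __run_ddp_spawn must mirror the put() order in ddp_train. A self-contained sketch of the pattern (the path and result values are placeholders, not PR code):

    import torch.multiprocessing as mp

    def worker(process_idx, q):
        # Mirrors ddp_train: rank 0 puts the best checkpoint path first,
        # then the results.
        if process_idx == 0:
            q.put('/tmp/best.ckpt')
            q.put({'val_loss': 0.1})

    if __name__ == '__main__':
        smp = mp.get_context('spawn')
        q = smp.SimpleQueue()
        mp.spawn(worker, args=(q,), nprocs=2)
        best_path = q.get()  # same order as the puts above
        results = q.get()
        print(best_path, results)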