
Trainer DDP should invoke load_spawn_weights() only in proc_rank == 0 #1335

Closed
areshytko opened this issue Apr 1, 2020 · 0 comments · Fixed by #1385
Labels: bug, help wanted


areshytko commented Apr 1, 2020

🐛 Bug

In DDP mode the Trainer should call load_spawn_weights() only in the process with proc_rank == 0, because save_spawn_weights() writes the temporary checkpoint only in that process (i.e. on node 0). Every other node therefore tries to load a checkpoint file that was never written, and trainer.fit() crashes with a FileNotFoundError.
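
A minimal sketch of the proposed guard, mirroring the rank check that save_spawn_weights() already performs (paraphrased, not the exact library code; the checkpoint-directory attribute name is an assumption):

import os

def load_spawn_weights(self, original_model):
    # only proc_rank == 0 wrote __temp_weight_ddp_end.ckpt, so only
    # proc_rank == 0 should try to read it back
    loaded_model = original_model
    if self.proc_rank == 0:
        # 'default_root_dir' stands in for whatever attribute holds the
        # directory save_spawn_weights() wrote the checkpoint to
        path = os.path.join(self.default_root_dir, '__temp_weight_ddp_end.ckpt')
        loaded_model = original_model.__class__.load_from_checkpoint(path)
        os.remove(path)  # clean up the temporary spawn checkpoint
    return loaded_model

The other ranks simply keep their in-memory weights, which stay in sync with rank 0 because DDP averages gradients across all processes during training.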

To Reproduce

Steps to reproduce the behavior:

  1. Set up a two-node cluster.
  2. Set SLURM_NODEID on each node: '0' on node 0 and '1' on node 1.
  3. Run python app.py on each node.
  4. Observe the following in stdout on node 1:
Traceback (most recent call last):
  File "app.py", line 166, in <module>
    main_()  # pylint: disable=no-value-for-parameter
  File "app.py", line 162, in main_
    trainer.fit(model)
  File "/home/ubuntu/anaconda3/envs/nightly_pt/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 593, in fit
    self.load_spawn_weights(model)
  File "/home/ubuntu/anaconda3/envs/nightly_pt/lib/python3.7/site-packages/pytorch_lightning/trainer/distrib_data_parallel.py", line 368, in load_spawn_weights
    loaded_model = original_model.__class__.load_from_checkpoint(path)
  File "/home/ubuntu/anaconda3/envs/nightly_pt/lib/python3.7/site-packages/pytorch_lightning/core/lightning.py", line 1353, in load_from_checkpoint
    checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
  File "/home/ubuntu/anaconda3/envs/nightly_pt/lib/python3.7/site-packages/torch/serialization.py", line 525, in load
    with _open_file_like(f, 'rb') as opened_file:
  File "/home/ubuntu/anaconda3/envs/nightly_pt/lib/python3.7/site-packages/torch/serialization.py", line 212, in _open_file_like
    return _open_file(name_or_buffer, mode)
  File "/home/ubuntu/anaconda3/envs/nightly_pt/lib/python3.7/site-packages/torch/serialization.py", line 193, in __init__
    super(_open_file, self).__init__(open(name, mode))
FileNotFoundError: [Errno 2] No such file or directory: '/home/ubuntu/pytorch-lightning-intro-guide/__temp_weight_ddp_end.ckpt'

Code sample

app.py:

import pathlib

import pytorch_lightning as pl
import torch
from torch.nn import functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms


class LitMNIST(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer_1 = torch.nn.Linear(28 * 28, 128)
        self.layer_2 = torch.nn.Linear(128, 256)
        self.layer_3 = torch.nn.Linear(256, 10)

        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def forward(self, x):
        batch_size, channels, width, height = x.size()
        x = x.view(batch_size, -1)
        x = self.layer_1(x)
        x = F.relu(x)
        x = self.layer_2(x)
        x = F.relu(x)
        x = self.layer_3(x)
        x = F.log_softmax(x, dim=1)
        return x

    def prepare_data(self):
        # transform
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

        # download
        data_dir = pathlib.Path.home() / 'data'
        mnist_train = datasets.MNIST(data_dir, train=True,
                                     download=True, transform=transform)
        mnist_test = datasets.MNIST(data_dir, train=False,
                                    download=True, transform=transform)

        # train/val split
        mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])

        # assign to use in dataloaders
        self.train_dataset = mnist_train
        self.val_dataset = mnist_val
        self.test_dataset = mnist_test

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=64)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=64)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=64)

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=1e-3)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)

        # add logging
        logs = {'loss': loss}
        return {'loss': loss, 'log': logs}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return {'val_loss': loss}

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return {'val_loss': loss}

    def test_epoch_end(self, outputs):
        avg_loss = torch.stack(  # pylint: disable=no-member
            [x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}

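    # Custom DDP connection setup using the NCCL backend. With no
    # init_method given, init_process_group() defaults to env://, so
    # MASTER_ADDR and MASTER_PORT must be set on every node.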
    def init_ddp_connection(self, proc_rank: int, world_size: int) -> None:
        torch.distributed.init_process_group(
            'nccl', rank=proc_rank, world_size=world_size)

def main():
    model = LitMNIST()

    gpus = 1
    num_nodes = 2
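    # world_size == gpus * num_nodes == 2; proc_rank 0 lives on the node
    # with SLURM_NODEID == 0, the only rank that saves the temporary
    # spawn checkpoint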

    trainer = pl.Trainer(gpus=gpus,
                         num_nodes=num_nodes,
                         distributed_backend='ddp',
                         max_epochs=3)
    trainer.fit(model)


if __name__ == '__main__':
    main()

Expected behavior

All workers on all nodes should finish without errors.

Environment

On each node:

cuda:
	GPU:
		Tesla K80
		Tesla K80
		Tesla K80
		Tesla K80
		Tesla K80
		Tesla K80
		Tesla K80
		Tesla K80
	available:           True
	version:             10.1
packages:
	numpy:               1.16.6
	pyTorch_debug:       False
	pyTorch_version:     1.4.0
	pytorch-lightning:   0.7.1
	tensorboard:         2.2.0
	tqdm:                4.44.1
system:
	OS:                  Linux
	architecture:
		64bit

	processor:           x86_64
	python:              3.7.7
	version:             #113-Ubuntu SMP Wed Jan 29 14:54:54 UTC 2020

areshytko added the bug and help wanted labels on Apr 1, 2020
areshytko changed the title from "Trainer DDP load_spawn_weights should happen only in proc_rank == 0" to "Trainer DDP should invoke load_spawn_weights() only in proc_rank == 0" on Apr 1, 2020