Accuracy metric returns a tuple of length num_gpus on ddp in 0.9.1rc4 #3676

Closed
xiadingZ opened this issue Sep 27, 2020 · 7 comments · Fixed by #3517
Labels: bug (Something isn't working) · help wanted (Open to be worked on)

Comments

@xiadingZ

Code:

        vid_acc = self.accuracy_video(video_labels_hat, video_labels)
        print(len(vid_acc), vid_acc)
        monitor = 0-vid_acc

In 0.9.1rc3, vid_acc is a tensor, but in rc4 it changes to a tuple. I want to use -vid_acc as the monitor value, so I think it should be a tensor.
With rc4 on macOS in CPU mode it is a tensor, but in Linux DDP mode it is a tuple of length num_gpus.

xiadingZ added the bug (Something isn't working) and help wanted (Open to be worked on) labels on Sep 27, 2020
@awaelchli
Member

Thanks for the report. Apologies, but I'm afraid I don't understand how this relates to our accuracy metric.
Looking at the code of pytorch_lightning.metrics.classification.Accuracy, I don't see any hint of why it would return a tuple.
Could you share the code of your accuracy_video function?

cc @SkafteNicki @justusschock

@xiadingZ
Author

This is my code:
0.9.1rc3:

from pytorch_lightning.metrics import Accuracy
self.accuracy_video = Accuracy(reduce_op='avg')

0.9.1rc4:

from pytorch_lightning.metrics import Accuracy
self.accuracy_video = Accuracy()

In 0.9.1rc4, reduce_op was changed to class_reduction, and I think the new default does what reduce_op='avg' did.
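
As a quick check of what I assume the defaults do (untested sketch; that class_reduction defaults to 'micro' in rc4 is my assumption):

import torch
from pytorch_lightning.metrics import Accuracy

# 0.9.1rc3: metric = Accuracy(reduce_op='avg')
# 0.9.1rc4: reduce_op is gone, class_reduction takes over
metric = Accuracy()

preds = torch.tensor([0, 2, 1, 3])
target = torch.tensor([0, 1, 1, 3])
acc = metric(preds, target)
print(type(acc), acc)  # expected on CPU: a single scalar tensor, tensor(0.7500)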

@awaelchli
Member

awaelchli commented Sep 28, 2020

Yes, it does.
However, I cannot reproduce what you describe in the issue on the master branch.
When running with ddp and ddp_spawn on 2 GPUs, I always get a single scalar accuracy per GPU. So on each GPU it is a scalar, as expected.

@xiadingZ
Author

I wrote an example to reproduce it:

import argparse

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

import pytorch_lightning as pl
from pytorch_lightning.metrics import Accuracy

class LitModel(pl.LightningModule):

    def __init__(self, hparams, *args, **kwargs):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, hparams.c)
        self.ac = Accuracy()

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        result = pl.TrainResult(minimize=loss)
        result.log('train_loss', loss)
        return result
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        label_hat = torch.argmax(y_hat, dim=1).type_as(y)
        acc = self.ac(label_hat, y)
        print(type(acc), acc)  # on DDP this prints a tuple, not a tensor
        monitor = 0-acc  # raises TypeError when acc is a tuple
        result = pl.EvalResult(checkpoint_on=monitor)
        result.log('acc', acc, sync_dist=True, prog_bar=True)
        return result

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.0005)
    
    def train_dataloader(self):
        dataset = MNIST('./', download=False, transform=transforms.ToTensor())
        train_loader = DataLoader(dataset, batch_size=128)
        return train_loader

    def val_dataloader(self):
        dataset = MNIST('./', download=False, transform=transforms.ToTensor())
        val_loader = DataLoader(dataset, batch_size=128)
        return val_loader
    
    
def main(hparams):
    model = LitModel(hparams)

    trainer = pl.Trainer(
        gpus=hparams.gpus,
        num_nodes=hparams.num_nodes,
        distributed_backend='ddp'
    )

    trainer.fit(model)


if __name__ == '__main__':
    # TRAIN
    parser = argparse.ArgumentParser(conflict_handler="resolve")
    parser = pl.Trainer.add_argparse_args(parser)
    parser.add_argument('--c', type=int, default=10)
    hparams = parser.parse_args()
    hparams.gpus = 2
    hparams.num_nodes = 2
    hparams.max_epochs = 5
    main(hparams)

Here is the error:

<class 'tuple'> (tensor([0.0312], device='cuda:1'), tensor([0.0781], device='cuda:1'), tensor([0.0938], device='cuda:1'), tensor([0.1406], device='cuda:1'))
<class 'tuple'> (tensor([0.0312], device='cuda:1'), tensor([0.0781], device='cuda:1'), tensor([0.0938], device='cuda:1'), tensor([0.1406], device='cuda:1'))
/mnt/lustre/maxiao1/anaconda3/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:37: UserWarning: The dataloader, val dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 64 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  warnings.warn(*args, **kwargs)
<class 'tuple'> (tensor([0.0312], device='cuda:0'), tensor([0.0781], device='cuda:0'), tensor([0.0938], device='cuda:0'), tensor([0.1406], device='cuda:0'))
<class 'tuple'> (tensor([0.0312], device='cuda:0'), tensor([0.0781], device='cuda:0'), tensor([0.0938], device='cuda:0'), tensor([0.1406], device='cuda:0'))
Traceback (most recent call last):
  File "train.py", line 78, in <module>
    main(hparams)
  File "train.py", line 66, in main
    trainer.fit(model)
  File "/mnt/lustre/maxiao1/anaconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 302, in fit
    results = self.accelerator_backend.train()
  File "/mnt/lustre/maxiao1/anaconda3/lib/python3.7/site-packages/pytorch_lightning/accelerators/ddp_backend.py", line 133, in train
    self.ddp_train_tmp(process_idx=self.task_idx, mp_queue=None, model=model)
  File "/mnt/lustre/maxiao1/anaconda3/lib/python3.7/site-packages/pytorch_lightning/accelerators/ddp_base_backend.py", line 170, in ddp_train_tmp
    results = self.train_or_test()
  File "/mnt/lustre/maxiao1/anaconda3/lib/python3.7/site-packages/pytorch_lightning/accelerators/base_backend.py", line 36, in train_or_test
    results = self.trainer.train()
  File "/mnt/lustre/maxiao1/anaconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 324, in train
    self.run_sanity_check(self.get_model())
  File "/mnt/lustre/maxiao1/anaconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 499, in run_sanity_check
    _, eval_results = self.run_evaluation(test_mode=False, max_batches=self.num_sanity_val_batches)
  File "/mnt/lustre/maxiao1/anaconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 430, in run_evaluation
    output = self.evaluation_loop.evaluation_step(test_mode, batch, batch_idx, dataloader_idx)
  File "/mnt/lustre/maxiao1/anaconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 145, in evaluation_step
    output = self.trainer.accelerator_backend.validation_step(args)
  File "/mnt/lustre/maxiao1/anaconda3/lib/python3.7/site-packages/pytorch_lightning/accelerators/ddp_base_backend.py", line 50, in validation_step
    output = self.training_step(args)
  File "/mnt/lustre/maxiao1/anaconda3/lib/python3.7/site-packages/pytorch_lightning/accelerators/ddp_base_backend.py", line 46, in training_step
    output = self.trainer.model(*args)
  File "/mnt/lustre/maxiao1/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 558, in __call__
    result = self.forward(*input, **kwargs)
  File "/mnt/lustre/maxiao1/anaconda3/lib/python3.7/site-packages/pytorch_lightning/overrides/data_parallel.py", line 182, in forward
    output = self.module.validation_step(*inputs[0], **kwargs[0])
  File "train.py", line 38, in validation_step
    monitor = 0-acc
TypeError: unsupported operand type(s) for -: 'int' and 'tuple'
(The second DDP process prints the same traceback, interleaved with the output above, ending in the same TypeError.)

@awaelchli
Member

OK, I can reproduce it.
The workaround is to use the functional version,
pytorch_lightning.metrics.functional.classification.accuracy,
instead of the module.

I'm not sure yet why the two behave differently.
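
Something like this (untested sketch, adapting your validation_step):

from pytorch_lightning.metrics.functional.classification import accuracy

def validation_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self(x)
    label_hat = torch.argmax(y_hat, dim=1).type_as(y)
    # the functional metric computes the value directly on this process
    acc = accuracy(label_hat, y)
    monitor = 0 - acc  # acc is a scalar tensor here, so the negation works
    result = pl.EvalResult(checkpoint_on=monitor)
    result.log('acc', acc, sync_dist=True, prog_bar=True)
    return result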

@justusschock
Member

The issue is that the functional accuracy computes the value directly, while the modular accuracy also does some syncing across processes. It definitely shouldn't behave like this; I'll have a look at it.
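
For illustration, the kind of all-gather that turns a per-process scalar into a tuple of length world_size looks roughly like this (a sketch, not the actual metric internals):

import torch
import torch.distributed as dist

def gather_metric(value: torch.Tensor):
    # every process contributes its scalar; afterwards each process
    # holds one tensor per participating process
    gathered = [torch.zeros_like(value) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, value)
    return tuple(gathered)  # tuple of length num_gpus, as reported above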

justusschock self-assigned this on Sep 28, 2020
@SkafteNicki
Member

This has to do with how aggregation is currently implemented in the class-based metrics.
It will be solved when #3517 is merged.

awaelchli linked a pull request on Sep 28, 2020 that will close this issue (#3517)