TensorMetric not updated to cuda device #2274

Closed · tayden opened this issue Jun 19, 2020 · 10 comments
Labels: bug (Something isn't working), help wanted (Open to be worked on), priority: 0 (High priority task)

tayden commented Jun 19, 2020

🐛 Bug

When tensors located on the GPU are passed through a TensorMetric's forward method, they are transferred to the CPU by that method. It seems that the 'cpu' device is not updated when training on a GPU; the cpu device appears to be hardcoded here

To Reproduce

Steps to reproduce the behavior:

import torch
import torch.nn.functional as F
from pytorch_lightning.metrics import TensorMetric


class FocalTverskyMetric(TensorMetric):
    def __init__(self, alpha: float, beta: float, gamma: float, smooth=1e-8):
        super().__init__("FocalTversky")
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.smooth = smooth

        # This line works as a workaround
        # self._device = torch.device('cuda')

    def _tversky_index_c(self, p: torch.Tensor, g: torch.Tensor):
        c = p.shape[1]
        p = p.permute(0, 2, 3, 1).reshape((-1, c))
        g = F.one_hot(g.flatten().long(), c)

        tp = torch.sum(torch.mul(p, g), dim=0)
        fn = torch.sum(torch.mul(1. - p, g), dim=0)
        fp = torch.sum(torch.mul(p, 1. - g), dim=0)
        return (tp + self.smooth) / (tp + self.alpha * fn + self.beta * fp + self.smooth)

    def forward(self, x, y):
        ti = self._tversky_index_c(x, y)
        res = (1 - ti).pow(1 / self.gamma)
        return torch.sum(res, dim=0)


if __name__ == '__main__':
    metric = FocalTverskyMetric(alpha=0.5, beta=0.5, gamma=1.)

    preds = torch.Tensor([[[[1.]], [[0.]]], [[[1.]], [[0.]]], [[[0.]], [[1.]]]]).cuda()
    assert preds.is_cuda  # Passes
    labels = torch.Tensor([[[1]], [[1]], [[0]]]).cuda()
    assert labels.is_cuda  # Passes

    loss = metric(preds, labels)
    assert loss.is_cuda  # Fails

When training in DDP mode with 16-bit precision, this metric throws the stack trace below. The error disappears in 32-bit mode, since it is the AMP grad_scaler that asserts the loss value is a CUDA tensor.

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0,1,2,3]
Using 16bit precision.
Using 16bit precision.
initializing ddp: GLOBAL_RANK: 1, MEMBER: 2/4
Using 16bit precision.
initializing ddp: GLOBAL_RANK: 2, MEMBER: 3/4
Using 16bit precision.
initializing ddp: GLOBAL_RANK: 3, MEMBER: 4/4
initializing ddp: GLOBAL_RANK: 0, MEMBER: 1/4
----------------------------------------------------------------------------------------------------
distributed_backend=ddp
All DDP processes registered. Starting ddp with 4 processes
----------------------------------------------------------------------------------------------------

  | Name      | Type               | Params
-------------------------------------------------
0 | model     | DeepLabV3          | 60 M
1 | calc_loss | FocalTverskyMetric | 0
2 | calc_iou  | IoU                | 0
Epoch 1:   0% 0/493 [00:00<?, ?it/s] Traceback (most recent call last):
  File "/opt/code/deeplabv3_lightning.py", line 269, in <module>
    'pred': predict
  File "/opt/conda/lib/python3.7/site-packages/fire/core.py", line 138, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/opt/conda/lib/python3.7/site-packages/fire/core.py", line 468, in _Fire
    target=component.__name__)
  File "/opt/conda/lib/python3.7/site-packages/fire/core.py", line 672, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/opt/code/deeplabv3_lightning.py", line 224, in train
    trainer.fit(model)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 895, in fit
    self.ddp_train(task, model)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/distrib_data_parallel.py", line 526, in ddp_train
    self.run_pretrain_routine(model)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1091, in run_pretrain_routine
    self.train()
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 374, in train
    self.run_training_epoch()
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 457, in run_training_epoch
    _outputs = self.run_training_batch(batch, batch_idx)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 633, in run_training_batch
    loss, batch_output = optimizer_closure()
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 611, in optimizer_closure
    model_ref.backward(self, closure_loss, optimizer, opt_idx)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/core/hooks.py", line 179, in backward
    self.trainer.scaler.scale(loss).backward()
  File "/opt/conda/lib/python3.7/site-packages/torch/cuda/amp/grad_scaler.py", line 156, in scale
    assert outputs.is_cuda
AssertionError
Traceback (most recent call last):
  File "deeplabv3_lightning.py", line 269, in <module>
    'pred': predict
  File "/opt/conda/lib/python3.7/site-packages/fire/core.py", line 138, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/opt/conda/lib/python3.7/site-packages/fire/core.py", line 468, in _Fire
    target=component.__name__)
  File "/opt/conda/lib/python3.7/site-packages/fire/core.py", line 672, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "deeplabv3_lightning.py", line 224, in train
    trainer.fit(model)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 910, in fit
    self.spawn_ddp_children(model)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/distrib_data_parallel.py", line 442, in spawn_ddp_children
    self.ddp_train(local_rank, model, is_master=True)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/distrib_data_parallel.py", line 526, in ddp_train
    self.run_pretrain_routine(model)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1091, in run_pretrain_routine
    self.train()
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 374, in train
    self.run_training_epoch()
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 457, in run_training_epoch
    _outputs = self.run_training_batch(batch, batch_idx)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 633, in run_training_batch
    loss, batch_output = optimizer_closure()
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 611, in optimizer_closure
    model_ref.backward(self, closure_loss, optimizer, opt_idx)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/core/hooks.py", line 179, in backward
    self.trainer.scaler.scale(loss).backward()
  File "/opt/conda/lib/python3.7/site-packages/torch/cuda/amp/grad_scaler.py", line 156, in scale
    assert outputs.is_cuda
AssertionError
Exception ignored in: <function tqdm.__del__ at 0x7ff26f244a70>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/tqdm/std.py", line 1086, in __del__
  File "/opt/conda/lib/python3.7/site-packages/tqdm/std.py", line 1293, in close
  File "/opt/conda/lib/python3.7/site-packages/tqdm/std.py", line 1471, in display
  File "/opt/conda/lib/python3.7/site-packages/tqdm/std.py", line 1089, in __repr__
  File "/opt/conda/lib/python3.7/site-packages/tqdm/std.py", line 1433, in format_dict
TypeError: cannot unpack non-iterable NoneType object

Expected behavior

The TensorMetric's self._device should be updated to match the current Trainer device during initialization.

Environment

  • CUDA:
    • GPU:
      • GeForce GTX 1050
    • available: True
    • version: 10.2
  • Packages:
    • numpy: 1.18.5
    • pyTorch_debug: False
    • pyTorch_version: 1.6.0.dev20200618
    • pytorch-lightning: 0.8.0
    • tensorboard: 2.1.1
    • tqdm: 4.46.1
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor: x86_64
    • python: 3.8.3
    • version: #41-Ubuntu SMP Wed Jun 3 18:57:02 UTC 2020

The bug is also present on AWS p3.8xlarge instances using the same environment, but with 4 NVIDIA Tesla V100s.

tayden added the bug (Something isn't working) and help wanted (Open to be worked on) labels on Jun 19, 2020

This comment has been minimized.

@tayden tayden closed this as completed Jun 19, 2020

tayden commented Jun 19, 2020

Never mind, the issue still persists in the 0.8.1 release.

@tayden tayden reopened this Jun 19, 2020

tayden commented Jun 19, 2020

I understand now that the metrics must be moved to the appropriate device, but it's not clear what the best way to do this is or if it should be handled automatically by the LightningModule somehow.

Currently I'm just setting self.calc_loss = FocalTverskyLoss(0.7, 0.3, 4./3.) in my LightningModule's __init__ function. Is there a way to register these metrics so that they'll be placed on the correct device when training with DDP etc.?


tayden commented Jun 19, 2020

Can metrics not be used for the loss computation?

williamFalcon (Contributor) commented

They should be able to...
@justusschock @SkafteNicki


justusschock commented Jun 21, 2020

This should be possible after you call super().__init__(), since after that, assignments of nn.Modules (the base class for metrics) are registered automatically.

This means that once you add the metric as an attribute of your LightningModule, it should be updated automatically.

For completeness: this happens because we automatically move the result to self._device, and if your metric hasn't been moved to the correct device, this defaults to cpu.
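
For illustration, here is a minimal sketch of that pattern, reusing the FocalTverskyMetric class defined in the report above (the module name, the stand-in network, and the hyper-parameter values are made up for the example):

import torch
import torch.nn as nn
from pytorch_lightning import LightningModule


class SegmentationModule(LightningModule):  # illustrative name, not from this thread
    def __init__(self):
        super().__init__()  # call this first so attribute assignments get registered
        self.model = nn.Conv2d(3, 2, kernel_size=1)  # stand-in for the real network
        # TensorMetric subclasses nn.Module, so assigning it as an attribute
        # registers it as a submodule of this LightningModule.
        self.calc_loss = FocalTverskyMetric(alpha=0.5, beta=0.5, gamma=1.)

    def forward(self, x):
        return torch.softmax(self.model(x), dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        # The intent: when the Trainer moves this module to the GPU, the metric
        # should be moved as well, so its result stays on the inputs' device.
        return {"loss": self.calc_loss(self(x), y)}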

tayden commented Jun 21, 2020

Based on the stack trace above, though, this movement off of the CPU and onto the cuda device doesn't seem to be happening. I'll try to post a minimal example using a metric as an attribute of a LightningModule soon. The stack trace is from a module that uses the FocalTverskyMetric code above, which also calls super().__init__().

justusschock commented Jun 22, 2020

>>> from pytorch_lightning.metrics import Accuracy
>>> from pytorch_lightning import LightningModule
>>> class DummyModule(LightningModule):
...     def __init__(self):
...         super().__init__()
...         self.acc = Accuracy()
...     def forward(self, x, y):
...         return self.acc(x, y)
...
>>> dummy = DummyModule()
>>> dummy.device
device(type='cpu')
>>> dummy.acc.device
device(type='cpu')
>>> dummy.to('cuda')
DummyModule(
  (acc): Accuracy()
)
>>> dummy.device
device(type='cuda')
>>> dummy.acc.device
device(type='cpu')

Unfortunately, I could also reproduce this behaviour for LightningModules. So with nested modules, only the device (and type) of the main model is updated, but not on the inner modules.

So this is not specific to metrics, but rather a general issue with our mixin here.
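
To make the failure mode concrete, here is a small, Lightning-free sketch (the class names are illustrative): nn.Module.to() recursively moves registered parameters and buffers of nested modules, but it never touches a plain Python attribute such as a cached _device.

import torch
import torch.nn as nn


class ToyMetric(nn.Module):
    """Stand-in for a metric that caches its device in a plain attribute."""

    def __init__(self):
        super().__init__()
        self._device = torch.device("cpu")             # plain attribute: ignored by .to()
        self.register_buffer("state", torch.zeros(1))  # buffer: moved by .to()


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.metric = ToyMetric()


if torch.cuda.is_available():
    model = ToyModel().to("cuda")
    print(model.metric.state.device)  # cuda:0 -- buffers follow the parent module
    print(model.metric._device)       # cpu    -- the cached attribute is left behind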

SkafteNicki (Member) commented

I guess this can be fixed by introducing a buffer to each metric, which is a feature we need to implement regardless of this problem if we want to accumulate stats over multiple batches. For now, simply changing the base class to something like:

# Imports added for completeness; the DeviceDtypeModuleMixin import path below is
# an assumption and may differ between pytorch-lightning versions.
from abc import ABC, abstractmethod

import torch

from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin


class Metric(DeviceDtypeModuleMixin, torch.nn.Module, ABC):
    """
    Abstract base class for metric implementation.

    Should be used to implement metrics that
    1. Return multiple Outputs
    2. Handle their own DDP sync
    """

    def __init__(self, name: str):
        """
        Args:
            name: the metric's name

        """
        super().__init__()
        self.name = name
        self._dtype = torch.get_default_dtype()
        self._device = torch.device('cpu')
        
        self.register_buffer('state', torch.zeros(1))
        self.reset_state()
        
    @abstractmethod
    def forward(self, *args, **kwargs) -> torch.Tensor:
        """
        Implements the actual metric computation.

        Returns:
            metric value

        """
        raise NotImplementedError
        
    def reset_state(self):
        self.state.zero_()  # Tensor.zero_ is the in-place reset (zeros_ does not exist)

    @property
    def device(self):
        """ From https://github.com/pytorch/pytorch/issues/7460#issuecomment-623200989 """
        devices = ({param.device for param in self.parameters()} |
                   {buf.device for buf in self.buffers()})
        if len(devices) != 1:
            raise RuntimeError('Cannot determine device: {} different devices found'
                               .format(len(devices)))
        return next(iter(devices))

Should work.
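
For illustration only, here is a hypothetical subclass of the sketched base class (DummyMetric is not library code). With the registered state buffer, the metric's device can be derived from its buffers after a .cuda() call:

class DummyMetric(Metric):  # hypothetical example
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return (x == y).float().mean()


if torch.cuda.is_available():
    metric = DummyMetric("dummy").cuda()
    print(metric.state.device)  # cuda:0 -- the buffer moved with the module
    print(metric.device)        # cuda:0 -- now derived from parameters/buffers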

awaelchli commented Aug 3, 2020

I think I fixed this recently in #2657

The cpu device seems to be hardcoded here

@tayden It only looks that way; that is just the initial value there. The DeviceDtypeMixin takes care of updating the device.
In your example here you need to move the metric to cuda as well:

metric = FocalTverskyMetric(alpha=0.5, beta=0.5, gamma=1.).cuda()

But of course, if you assign the metric as a property of the LightningModule, then the metric will also be moved automatically to the device the LightningModule is on.

I agree with @SkafteNicki. All metrics that save tensors must register them as buffers.
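
As a concrete (hypothetical) illustration of that rule: any tensor a metric keeps between calls should go through register_buffer so it travels with the parent module on .to()/.cuda(); the RunningMean class below is an example, not part of the library.

import torch
import torch.nn as nn


class RunningMean(nn.Module):  # hypothetical example, not pytorch-lightning code
    def __init__(self):
        super().__init__()
        # Registered buffers are moved by .to()/.cuda(), unlike plain attributes.
        self.register_buffer("total", torch.zeros(1))
        self.register_buffer("count", torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.total += x.sum()
        self.count += x.numel()
        return self.total / self.count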

@justusschock justusschock added this to To do in Metrics package via automation Sep 1, 2020
Metrics package automation moved this from To do to Done Sep 1, 2020