When using dp mode, only torch.Tensor can be used as the return value of the *_step function. #1904

Closed
shimacos37 opened this issue May 20, 2020 · 3 comments
Labels: help wanted (Open to be worked on), won't fix (This will not be worked on)

shimacos37 commented May 20, 2020

🐛 Bug

To Reproduce

Steps to reproduce the behavior:

  1. Go to ./pl_examples/basic_examples
  2. Run python gpu_template.py --gpus 2 --distributed_backend dp
  3. See error

Code sample

Running the command above produces the error below.

Validation sanity check: 0it [00:00, ?it/s]Traceback (most recent call last):
  File "/root/workdir/pytorch-lightning/pl_examples/basic_examples/gpu_template.py", line 80, in <module>
    main(hyperparams)
  File "/root/workdir/pytorch-lightning/pl_examples/basic_examples/gpu_template.py", line 41, in main
    trainer.fit(model)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 853, in fit
    self.dp_train(model)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/distrib_parts.py", line 578, in dp_train
    self.run_pretrain_routine(model)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1001, in run_pretrain_routine
    False)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 277, in _evaluate
    output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 424, in evaluation_forward
    output = model(*args)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/overrides/data_parallel.py", line 66, in forward
    return self.gather(outputs, self.output_device)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 165, in gather
    return gather(outputs, output_device, dim=self.dim)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/parallel/scatter_gather.py", line 68, in gather
    res = gather_map(outputs)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/parallel/scatter_gather.py", line 62, in gather_map
    for k in out))
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/parallel/scatter_gather.py", line 62, in <genexpr>
    for k in out))
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/parallel/scatter_gather.py", line 63, in gather_map
    return type(out)(map(gather_map, zip(*outputs)))
TypeError: zip argument #1 must support iteration

This error comes from gather_map in torch/nn/parallel/scatter_gather.py (https://github.com/pytorch/pytorch/blob/f4f0dd470c7eb51511194a52e87f0ceec5d4e05e/torch/nn/parallel/scatter_gather.py#L47): any value that is not a tensor, dict, or other container falls through to the zip(*outputs) fallback, which cannot iterate over plain Python numbers.
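
As an illustration, here is a minimal standalone reproduction; it is not from the original report, the values are made up, and the plain-int key is placed first in the dict so the failure triggers before any tensor is gathered (so it also fires with CPU tensors):

import torch
from torch.nn.parallel.scatter_gather import gather

# Simulated outputs of one validation_step per DP replica.
# "n_pred" is a plain Python int; gather_map has no branch for numbers,
# so it falls through to zip(*outputs) and raises the TypeError above.
outputs = [
    {"n_pred": 32, "val_loss": torch.tensor(0.5)},
    {"n_pred": 32, "val_loss": torch.tensor(0.7)},
]
gather(outputs, 0)  # TypeError: zip argument #1 must support iteration
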
The error can be worked around by changing this code in ./pl_examples/models/lightning_template.py

def validation_step(self, batch, batch_idx):
    """
    Lightning calls this inside the validation loop with the data from the validation dataloader
    passed in as `batch`.
    """
    x, y = batch
    y_hat = self(x)
    val_loss = F.cross_entropy(y_hat, y)
    labels_hat = torch.argmax(y_hat, dim=1)
    n_correct_pred = torch.sum(y == labels_hat).item()  # .item() returns a plain Python int
    return {'val_loss': val_loss, "n_correct_pred": n_correct_pred, "n_pred": len(x)}  # len(x) is a plain int as well

def validation_epoch_end(self, outputs):
    """
    Called at the end of validation to aggregate outputs.
    :param outputs: list of individual outputs of each validation step.
    """
    avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
    val_acc = sum([x['n_correct_pred'] for x in outputs]) / sum(x['n_pred'] for x in outputs)
    tensorboard_logs = {'val_loss': avg_loss, 'val_acc': val_acc}
    return {'val_loss': avg_loss, 'log': tensorboard_logs}

to

def validation_step(self, batch, batch_idx):
    """
    Lightning calls this inside the validation loop with the data from the validation dataloader
    passed in as `batch`.
    """
    x, y = batch
    y_hat = self(x)
    val_loss = F.cross_entropy(y_hat, y)
    labels_hat = torch.argmax(y_hat, dim=1)
    n_correct_pred = torch.sum(y == labels_hat)  # keep the count as a tensor
    return {
        "val_loss": val_loss,
        "n_correct_pred": n_correct_pred,
        "n_pred": torch.tensor(len(x)).to(val_loss.device),  # wrap the batch size in a tensor on the same device
    }

def validation_epoch_end(self, outputs):
    """
    Called at the end of validation to aggregate outputs.
    :param outputs: list of individual outputs of each validation step.
    """
    avg_loss = (
        torch.stack([x["val_loss"].detach().cpu() for x in outputs]).mean().item()
    )
    val_acc = np.sum(
        [x["n_correct_pred"].detach().cpu().numpy() for x in outputs]
    ) / np.sum([x["n_pred"].detach().cpu().numpy() for x in outputs])
    tensorboard_logs = {"val_loss": avg_loss, "val_acc": val_acc}
    print({"val_loss": avg_loss, "log": tensorboard_logs})
    return {"val_loss": avg_loss, "log": tensorboard_logs}

But this approach is not elegant ...
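
A slightly more generic workaround is to tensorize the step output automatically before returning it. This is only a sketch; the tensorize helper is hypothetical and not part of Lightning:

import numbers
import torch

def tensorize(output, device):
    # Convert plain Python numbers in a *_step output dict into 0-dim
    # tensors on the given device, so DataParallel's gather can handle
    # every value.
    return {
        k: torch.tensor(v, device=device) if isinstance(v, numbers.Number) else v
        for k, v in output.items()
    }

validation_step could then simply end with return tensorize({...}, val_loss.device), and validation_epoch_end would receive only tensors.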

Expected behavior

  • Return values other than torch.Tensor are allowed (one possible direction is sketched below).
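
One possible direction, sketched below, assumes that Lightning's DP wrapper ultimately delegates to torch.nn.parallel.DataParallel.gather; the class name ScalarFriendlyDataParallel is made up, and this is not Lightning's actual implementation:

import numbers
import torch
from torch.nn.parallel import DataParallel

class ScalarFriendlyDataParallel(DataParallel):
    # Hypothetical wrapper: convert plain Python numbers in each replica's
    # output dict to 0-dim tensors before delegating to the normal gather.
    def gather(self, outputs, output_device):
        def to_tensor(value):
            if isinstance(value, numbers.Number):
                return torch.tensor(value, device=output_device)
            return value

        outputs = [
            {k: to_tensor(v) for k, v in out.items()} if isinstance(out, dict) else out
            for out in outputs
        ]
        return super().gather(outputs, output_device)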

Environment

  • PyTorch Version: 1.5
  • OS (e.g., Linux): Ubuntu 18.04
  • How you installed PyTorch: conda
  • Python version: 3.7.7
  • CUDA/cuDNN version: 10.2


shimacos37 added the help wanted label on May 20, 2020
github-actions (Contributor) commented

Hi! Thanks for your contribution, great first issue!

nsarang (Contributor) commented May 24, 2020

I think this is the same as #1861.

stale bot commented Jul 23, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

stale bot added the won't fix label on Jul 23, 2020
stale bot closed this as completed on Aug 1, 2020