
checkpoint saving stuck when use multiple GPUs #6495

Closed
sun-peach opened this issue Mar 12, 2021 · 14 comments
Labels
bug (Something isn't working) · checkpointing (Related to checkpointing) · distributed (Generic distributed-related topic) · help wanted (Open to be worked on) · priority: 0 (High priority task) · waiting on author (Waiting on user action, correction, or update)

Comments

@sun-peach

sun-peach commented Mar 12, 2021

🐛 Bug

When I use multiple GPUs, the model-saving step gets stuck, while everything works perfectly when I use only one GPU.

Please reproduce using the BoringModel

# Standard-library and third-party imports needed by this snippet; project-specific
# helpers (Conditional_Source_Separation, fourier, init_weights_functional,
# get_estimation) are assumed to be defined elsewhere in the repository.
from abc import ABCMeta, abstractmethod
from argparse import ArgumentParser
from typing import Any, List

import numpy as np
import torch
import wandb
from pytorch_lightning.loggers import WandbLogger


class Spectrogram_based(Conditional_Source_Separation, metaclass=ABCMeta):

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)

        parser.add_argument('--n_fft', type=int, default=2048)
        parser.add_argument('--hop_length', type=int, default=1024)
        parser.add_argument('--num_frame', type=int, default=128)
        parser.add_argument('--spec_type', type=str, default='complex')
        parser.add_argument('--spec_est_mode', type=str, default='mapping')

        parser.add_argument('--train_loss', type=str, default='spec_mse')
        parser.add_argument('--val_loss', type=str, default='raw_l1')
        parser.add_argument('--unfreeze_stft_from', type=int, default=-1)  # -1 means never.

        return Conditional_Source_Separation.add_model_specific_args(parser)

    def __init__(self, n_fft, hop_length, num_frame,
                 spec_type, spec_est_mode,
                 conditional_spec2spec,
                 optimizer, lr,
                 train_loss, val_loss, hparams=None
                 ):
        super(Spectrogram_based, self).__init__(n_fft, hop_length, num_frame,
                                                optimizer, lr, hparams)

        self.n_fft = n_fft
        self.hop_length = hop_length
        self.num_frame = num_frame

        assert spec_type in ['magnitude', 'complex']
        assert spec_est_mode in ['masking', 'mapping']
        self.magnitude_based = spec_type == 'magnitude'
        self.masking_based = spec_est_mode == 'masking'
        self.stft = fourier.multi_channeled_STFT(n_fft=n_fft, hop_length=hop_length)
        self.stft.freeze()

        self.spec2spec = conditional_spec2spec
        self.valid_estimation_dict = {}
        self.val_loss = val_loss
        self.train_loss = train_loss

        self.init_weights()

    def init_weights(self):
        init_weights_functional(self.spec2spec,
                                self.spec2spec.activation)

    def training_step(self, batch, batch_idx):
        mixture_signal, target_signal, condition = batch
        loss = self.train_loss(self, mixture_signal, condition, target_signal)
        self.log('train_loss', loss, prog_bar=False, logger=True, on_step=False, on_epoch=True,
                 reduce_fx=torch.mean)
        return loss

    # Validation Process
    def on_validation_epoch_start(self):
        for target_name in self.target_names:
            self.valid_estimation_dict[target_name] = {mixture_idx: {}
                                                       for mixture_idx
                                                       in range(14)}

    def validation_step(self, batch, batch_idx):

        mixtures, targets, mixture_ids, window_offsets, input_conditions, target_names = batch

        loss = self.val_loss(self, mixtures, input_conditions, targets)

        self.log('raw_val_loss', loss, prog_bar=False, logger=False, reduce_fx=torch.mean)

        # Result Cache
        if 0 in mixture_ids.view(-1):
            estimated_targets = self.separate(mixtures, input_conditions)[:, self.trim_length:-self.trim_length]
            targets = targets[:, self.trim_length:-self.trim_length]

            for mixture, mixture_idx, window_offset, input_condition, target_name, estimated_target \
                    in zip(mixtures, mixture_ids, window_offsets, input_conditions, target_names, estimated_targets):

                if mixture_idx == 0:
                    self.valid_estimation_dict[target_name][mixture_idx.item()][
                        window_offset.item()] = estimated_target.detach().cpu().numpy()
        return loss

    def validation_epoch_end(self, outputs: List[Any]) -> None:
        for idx in [0]:
            estimation = {}
            for target_name in self.target_names:
                estimation[target_name] = get_estimation(idx, target_name, self.valid_estimation_dict)
                if estimation[target_name] is None:
                    continue
                if estimation[target_name] is not None:
                    estimation[target_name] = estimation[target_name].astype(np.float32)

                    if self.current_epoch > 1 and isinstance(self.logger, WandbLogger):
                        track = estimation[target_name]
                        if track.shape[0] > 40 * 44100:
                            track = track[44100 * 20:44100 * 40]

                        self.logger.experiment.log({'result_sample_{}_{}'.format(self.current_epoch, target_name): [
                            wandb.Audio(track, caption='{}_{}'.format(idx, target_name), sample_rate=44100)]})

        reduced_loss = torch.stack(outputs).mean()
        self.log('val_loss', reduced_loss, prog_bar=False, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        print(reduced_loss)

    @abstractmethod
    def to_spec(self, input_signal) -> torch.Tensor:
        pass

    @abstractmethod
    def separate(self, input_signal, input_condition) -> torch.Tensor:
        pass

To Reproduce

My checkpoint_callback is:

    checkpoint_callback = ModelCheckpoint(
        dirpath=ckpt_path,
        save_top_k=save_top_k,
        verbose=True,
        monitor="val_loss",
        save_last=False,
        period=args["check_val_every_n_epoch"],
        save_weights_only=args["save_weights_only"],
    )

For multi-GPU runs I use the command python main.py --gpus 4 --distributed_backend ddp, and for single-GPU runs I use python main.py --gpus 1. I did not change anything else.

Expected behavior

The model is supposed to save smoothly; instead, the run gets stuck at the checkpoint-saving step. GPU utilization stays at 100% and never changes. Please see the figures below:

[Figure: GPU utilization stays at 100% forever]

[Figure: saving is stuck at epoch 0]

Environment

  • PyTorch Lightning version: 1.2.1
  • PyTorch version (e.g., 1.0): 1.4
  • OS (e.g., Linux): Ubuntu 18.04
  • How you installed PyTorch (conda, pip, source): pip
  • Build command you used (if compiling from source):
  • Python version:
  • CUDA/cuDNN version: 10.2
  • GPU models and configuration:
  • Any other relevant information:
@sun-peach sun-peach added bug Something isn't working help wanted Open to be worked on labels Mar 12, 2021
@awaelchli
Contributor

it is possible that this PR will solve it: #6410

@awaelchli awaelchli added checkpointing Related to checkpointing distributed Generic distributed-related topic labels Mar 14, 2021
@azouaoui-cv

This is a huge pain for me at the moment too. The model checkpointing routine gets stuck at the end of epoch X, where X is random (e.g. 5, 10, or whatever).
@awaelchli is there any cheap workaround, or do we have to wait until the PR you mentioned gets merged?
I would be fine with the second option if it were not for an upcoming deadline set on Wednesday.

Best,
Alex

@awaelchli
Contributor

I just saw the issue and the other PR, and it's just a guess; I can't say for sure that it fixes this problem, also because the code sample you provided does not run directly.
We also don't know for sure that the problem here is the saving itself: all we see is that the last printed message relates to the checkpoint. The GPU utilization stuck at 100% is another indication that the problem might not be with checkpointing.

One thing I noticed however is this suspicious list of arguments:

    checkpoint_callback = ModelCheckpoint(
        dirpath=ckpt_path,
        save_top_k=save_top_k,
        verbose=True,
        monitor="val_loss",
        save_last=False,
        period=args["check_val_every_n_epoch"],
        save_weights_only=args["save_weights_only"],
    )

This doesn't look right. The checkpoint callback does not accept an argument save_weights_only, and if you set check_val_every_n_epoch=x in the Trainer, then period in ModelCheckpoint should probably not be set.

I suggest these steps:

  • remove the save_weights_only argument from ModelCheckpoint
  • remove period and set check_val_every_n_epoch in the Trainer instead (see the sketch after this list)
  • upgrade to master pip install git+https://github.com/PyTorchLightning/pytorch-lightning@master
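
For reference, a minimal sketch of what this could look like, purely as an illustration: ckpt_path, save_top_k, and args are assumed to be defined as in the snippet above, and the exact Trainer flags may differ slightly across Lightning versions.

    # Hypothetical sketch of the suggested setup (not a verified fix):
    # validation frequency is configured on the Trainer instead of via period,
    # and save_weights_only is dropped from the callback.
    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint

    checkpoint_callback = ModelCheckpoint(
        dirpath=ckpt_path,        # assumed defined, as in the original snippet
        save_top_k=save_top_k,    # assumed defined, as in the original snippet
        verbose=True,
        monitor="val_loss",
        save_last=False,
    )

    trainer = Trainer(
        gpus=4,
        accelerator="ddp",        # corresponds to --distributed_backend ddp
        check_val_every_n_epoch=args["check_val_every_n_epoch"],
        callbacks=[checkpoint_callback],
    )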

@azouaoui-cv

azouaoui-cv commented Mar 14, 2021

I tried running my experiments from pytorch-lightning@bugfix/broadcast_2, but the issue remains, although it seems less prevalent (i.e. 11/14 jobs are still running 4 hours into the training).
Even though this is mainly my responsibility, I have so far burnt close to 3000 GPU hours of a limited computational budget, only to realize my jobs had silently failed. Needless to say, it's slightly annoying.
Is there anything I can do to help fix this bug?

EDIT: I can, for instance, provide a minimal working example tailored to my use case. Note that I mainly just tweaked the official ImageNet PL example.

@awaelchli awaelchli added the priority: 0 High priority task label Mar 14, 2021
@awaelchli
Contributor

Is there anything I can do to help fix this bug?

Sorry for the standard answer, but a minimal working example (ready to run) would be best, because then we can start debugging directly with minimal guesswork. I understand that given the conference deadline this is probably too much work (I myself am also submitting to ICCV next week).

@azouaoui-cv

I will see what I can do. My only concern for reproducibility is that the issue seems quite random so far (not all runs are affected, and it happens at a random epoch). I'm not sure I would be able to reproduce it with a simple BoringModel, for example. But yes, I could give it a try!

@sun-peach
Author

@awaelchli Thanks. I can remove save_weights_only. The logic of setting period equal to args["check_val_every_n_epoch"] is that I would like to run a validation check whenever I save the model. Based on what you said, when I set check_val_every_n_epoch in the Trainer, I don't need it in the callback? It will automatically run the check periodically?

Thank you once again.

@sun-peach
Author

@awaelchli @inzouzouwetrust Hi, just want to update that when I switch to the "dp" backend, everything is OK. Hope this helps you identify the problem.

Thanks.

@tchaton
Contributor

tchaton commented Mar 21, 2021

Dear @sun-peach,

Any chance you could provide a reproducible script?

Would you mind trying the following:

  • Master: we recently merged a PR about DDP hanging.
  • Add sync_dist=True within all your self.log calls, just to see if the problem remains (a minimal sketch follows this list).
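
For illustration, a minimal sketch of the second suggestion applied to the validation_step from the issue (names such as self.val_loss come from the original snippet and are assumed to exist):

    # Hypothetical example: passing sync_dist=True so the logged value is reduced
    # across DDP processes before it is used for checkpoint monitoring.
    def validation_step(self, batch, batch_idx):
        mixtures, targets, mixture_ids, window_offsets, input_conditions, target_names = batch
        loss = self.val_loss(self, mixtures, input_conditions, targets)
        self.log("raw_val_loss", loss, prog_bar=False, logger=False, sync_dist=True)
        return loss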

Best,
T.C

@yukw777
Contributor

yukw777 commented Mar 22, 2021

@tchaton I see this issue with PyTorch 1.7 and PL 1.2.4, which has the fixes that are supposed to address DDP hanging. sync_dist=True doesn't help either (it's actually set to True in the BoringModel code). My script doesn't even finish the validation loop: the GPU usage goes to 100% at some point and the training script simply hangs.

@tchaton
Contributor

tchaton commented Apr 6, 2021

Dear @azouaoui-cv, @yukw777,

Would it be possible for you to share a reproducible script for us to work on?

Best,
T.C

@edenlightning edenlightning added the waiting on author Waiting on user action, correction, or update label Apr 6, 2021
@yukw777
Contributor

yukw777 commented Apr 8, 2021

@tchaton I no longer have access to a multi-gpu machine. I'd like to try reproducing it once I get a hold of one. :/ @azouaoui-cv did the issue ever go away?

@flukeskywalker

I've had exactly the same problem, and found that it is caused specifically when checkpoint saving is done based on monitoring a quantity (using DDP on multiple GPUs).

@sun-peach do the hangups go away if you remove the monitor-based saving and simply checkpoint at every epoch (see the sketch below)?
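
For reference, a minimal sketch of such a monitor-free checkpoint callback; the dirpath is a placeholder, and the exact save_top_k semantics may vary slightly across Lightning versions:

    # Hypothetical sketch: checkpoint every epoch without monitoring a metric,
    # so the callback does not need to agree on a monitored value across DDP ranks.
    from pytorch_lightning.callbacks import ModelCheckpoint

    checkpoint_callback = ModelCheckpoint(
        dirpath="checkpoints/",  # placeholder path
        monitor=None,            # no metric-based selection
        save_top_k=-1,           # keep a checkpoint from every epoch
    )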

@edenlightning edenlightning added this to the v1.3 milestone Apr 27, 2021
@Borda Borda modified the milestones: v1.3, v1.3.x May 6, 2021
@edenlightning
Contributor

Closing for now; please reopen with a reproducible script.
