
validation_step_end and training_step_end usage #2435

Closed
pamparana34 opened this issue Jun 30, 2020 · 5 comments
Labels
question Further information is requested

Comments

@pamparana34

I cannot seem to find any examples of how to collect the outputs from all batches in the validation and training steps when using ddp. I am currently training on 4 GPUs with ddp, and my validation loop is as follows:

    def validation_step(self, batch, batch_idx):
        anchor, positive, negatives = batch
        negatives = negatives.transpose(0, 1)
        losses = []
        for i in range(len(negatives)):
            anchor_out, positive_out, negative_out = self.forward_train(anchor,
                                                                        positive,
                                                                        negatives[i])
            loss_val = self.lossfn(anchor_out, positive_out, negative_out)
            losses.append(loss_val)

        loss_val = torch.stack(losses).mean()
        return {'val_loss': loss_val}

    def validation_epoch_end(self, outputs):
        loss_val = torch.stack([x['val_loss'] for x in outputs]).mean()
        log_dict = {'validation_loss': loss_val, 'step': self.current_epoch}

        print('Mean Val loss: ', loss_val.item())
        return {'log': log_dict, 'val_loss': log_dict['validation_loss'], 'progress_bar': log_dict}

At the moment, when computing val_loss, only one of the processes is taken into account, so my statistics are not computed over the whole validation dataset (I think the same is true for my training set), and I would like them to be over the whole dataset. To that effect, I need to gather the outputs from all the GPUs.

I see that there are validation_step_end and training_step_end hooks, but I cannot find many examples of their usage. Could someone please comment on whether they can be used for what I am trying to do, i.e. compute my training loss and validation loss over the whole dataset when reporting? A small example would be really useful for newbies like me.

For completeness my training loop is as follows:

    def training_step(self, batch, batch_idx):
        anchor, positive, negative = batch
        anchor_out, positive_out, negative_out = self.forward_train(anchor,
                                                                    positive,
                                                                    negative)

        loss_val = self.lossfn(anchor_out, positive_out, negative_out)
        return {'loss': loss_val}

    def training_epoch_end(self, outputs):
        loss_val = torch.stack([x['loss'] for x in outputs]).mean()
        log_dict = {'training_loss': loss_val, 'step': self.current_epoch}
        return {'log': log_dict}

I see there are several versions of this same question here. So, I think it would really help to have a small example of how to do this.

@pamparana34 pamparana34 added the question Further information is requested label Jun 30, 2020
@pamparana34 pamparana34 changed the title validation_step_end usage validation_step_end and training_step_end usage Jun 30, 2020
@junwen-austin

junwen-austin commented Aug 17, 2020

@pamparana34 I believe even with validation_step_end in version 0.8.5, you still cannot get the metrics over the entire dataset. What you can get with validation_step_end is the metrics over one complete batch (one complete batch being the combination of the sub-batches on all GPUs at a given step). See the recent comments in #973
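
For illustration, here is a rough sketch of the pattern the *_step_end hooks were designed for in dp/ddp2 mode, where each GPU processes a slice of one batch and the step_end hook receives the gathered outputs; the pred/target keys and the generic self(x) forward call are placeholders, not code from this issue:

    def validation_step(self, batch, batch_idx):
        # with dp/ddp2, `batch` is only the slice of the current batch on this GPU
        x, y = batch
        return {'pred': self(x), 'target': y}

    def validation_step_end(self, outputs):
        # `outputs` holds the gathered tensors for the complete batch (all GPUs),
        # so the loss is computed over the full batch rather than a per-GPU slice
        loss = self.lossfn(outputs['pred'], outputs['target'])
        return {'val_loss': loss}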

@awaelchli
Contributor

@pamparana34 @junwen-austin
You can now do this:

    def validation_step(self, batch, batch_idx):
        anchor, positive, negatives = batch
        negatives = negatives.transpose(0, 1)
        losses = []
        for i in range(len(negatives)):
            anchor_out, positive_out, negative_out = self.forward_train(anchor,
                                                                        positive,
                                                                        negatives[i])
            loss_val = self.lossfn(anchor_out, positive_out, negative_out)
            losses.append(loss_val)

        loss_val = torch.stack(losses).mean()
        result = EvalResult()
        result.log("val_loss", loss_val, sync_dist=True)  # sync_dist will compute the mean over all processes
        return result

and no need for validation_epoch_end. The results object can average your val_losses across the dist group. The same can be done for your training loop. Let me know if that works for you.
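
For reference, a minimal sketch of the training-loop counterpart, assuming the same triplet setup from the question and the TrainResult API from the same Lightning versions (on newer releases, self.log(..., sync_dist=True) inside training_step replaces the Result objects):

    def training_step(self, batch, batch_idx):
        anchor, positive, negative = batch
        anchor_out, positive_out, negative_out = self.forward_train(anchor,
                                                                    positive,
                                                                    negative)
        loss_val = self.lossfn(anchor_out, positive_out, negative_out)

        result = TrainResult(minimize=loss_val)  # minimize tells Lightning which tensor to backpropagate
        result.log('train_loss', loss_val, sync_dist=True)  # mean over all processes
        return result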

@awaelchli
Contributor

Closing this, I am confident my answer applies to your use case. But if something does not work, let me know.

@awaelchli
Contributor

awaelchli commented Sep 20, 2020

@MaveriQ

MaveriQ commented Feb 28, 2022

I am not sure if I should start a new issue or continue this one. My query is about usage of training_step_end, so I hope this is an appropriate place (let me know if I should instead have a separate issue).

My use-case involves collecting outputs from multiple batches (say 5) before I calculate the loss. So I am hoping I can use training_step to collect outputs from those 5 batches and use training_step_end to calculate the loss and back-propagate. My understanding is that since I calculate loss only once after every 5th batch, I can't use gradient accumulation.

What I am not sure about is whether I can use training_step without returning the loss and instead return the loss only from training_step_end. Thank you for your help.
