training_epoch_end log output gets combined with next epoch training #2455

Closed
ameliatqy opened this issue Jul 1, 2020 · 8 comments
Labels: bug (Something isn't working), help wanted (Open to be worked on), priority: 0 (High priority task)

ameliatqy commented Jul 1, 2020

🐛 Bug

I put a 'training_epoch_end' function in my LightningModule and have it return this dictionary:
{'log':{'train_loss': tensor(0.3616, device='cuda:0'), 'epoch': 0}

I checked the run_training_epoch_end function in the PyTorch Lightning library, and it seems to be working normally: log_epoch_metrics shows the 'log' part of the dictionary produced by the 'training_epoch_end' function:
{'train_loss': tensor(0.3616, device='cuda:0'), 'epoch': 0}

This is then sent off to the logger. But there is a problem: the logger tries to combine the dictionary above with the results from the training step of the next epoch. When I check the variable self._metrics_to_agg, I get the following result. Of course, it is impossible to aggregate these dictionaries, as they have different keys. The main problem seems to be that the code combines the log output of run_training_epoch_end with the results of the next training batch:
[{'train_loss': tensor(0.3616, device='cuda:0'), 'epoch': 0}, {'loss': 0.48756, 'input': ..., 'ouput': ...}]
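
For reference, the aggregation failure is easy to reproduce outside the trainer. A minimal sketch (placeholder values, not the actual metrics), assuming the default numpy mean aggregation ends up reducing values that are still dicts:

    import numpy as np

    # numpy builds an object array of dicts and then tries dict + dict, which raises
    # TypeError: unsupported operand type(s) for +: 'dict' and 'dict'
    np.mean([{'train_loss': 0.3616, 'epoch': 0}, {'loss': 0.48756}])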

Any ideas on how to solve this problem? I would appreciate your help! Here is the whole error stack:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-5-a4569c0ab5cd> in <module>
      3 logging.getLogger('allennlp').setLevel(logging.INFO)
      4 
----> 5 train("/shared/labs/workflow/configs/ner_span_path_2.jsonnet", "/Dropbox/EvidLabs/atee/lab_data_enrichment_project/train/DataEnrichmentSpanPath/DataEnrichmentSpanPath__INV_M_span_path_42_43", AllenNlpLightningModule, tags=[])
      6 
      7 # train("/shared/labs/workflow/configs/candidate_span_grouper.jsonnet", "./DataEnrichmentBunchedCrfTaggerTest", AllenNlpLightningModule, tags=["Frozen"])

/code/evid-research/evid2/pytorch_lightning/trainer/config_file_trainer.py in train(config_file, project_dir, lightning_module, run_id, run_name_override, tags, dry_run)
    171     # 3 START TRAINING
    172     # ------------------------
--> 173     trainer.fit(model)
    174 
    175 

/code/pytorch-lightning/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloader, val_dataloaders)
    971         # easier to avoid NCCL issues
    972         elif self.use_dp:
--> 973             self.dp_train(model)
    974 
    975         elif self.use_horovod:

/code/pytorch-lightning/pytorch_lightning/trainer/distrib_parts.py in dp_train(self, model)
    265         model = LightningDataParallel(model, device_ids=device_ids)
    266 
--> 267         self.run_pretrain_routine(model)
    268 
    269         model.forward = model_autocast_original_forward

/code/pytorch-lightning/pytorch_lightning/trainer/trainer.py in run_pretrain_routine(self, model)
   1154 
   1155         # CORE TRAINING LOOP
-> 1156         self.train()
   1157 
   1158     def test(

/code/pytorch-lightning/pytorch_lightning/trainer/training_loop.py in train(self)
    368                 # RUN TNG EPOCH
    369                 # -----------------
--> 370                 self.run_training_epoch()
    371 
    372                 if self.max_steps and self.max_steps <= self.global_step:

/code/pytorch-lightning/pytorch_lightning/trainer/training_loop.py in run_training_epoch(self)
    480             # SAVE METRICS TO LOGGERS
    481             # -----------------------------------------
--> 482             self.save_train_loop_metrics_to_loggers(batch_idx, batch_output)
    483 
    484             # progress global step according to grads progress

/code/pytorch-lightning/pytorch_lightning/trainer/training_loop.py in save_train_loop_metrics_to_loggers(self, batch_idx, batch_output)
    564             # logs user requested information to logger
    565             print("batch_output.batch_log_metrics", batch_output.batch_log_metrics.keys())
--> 566             self.log_metrics(batch_output.batch_log_metrics, batch_output.grad_norm_dic)
    567 
    568     def save_loggers_in_training_loop(self, batch_idx):

/code/pytorch-lightning/pytorch_lightning/trainer/logging.py in log_metrics(self, metrics, grad_norm_dic, step)
     72         if self.is_global_zero and self.logger is not None:
     73             print("scalar_metrics", scalar_metrics.keys())
---> 74             self.logger.agg_and_log_metrics(scalar_metrics, step=step)
     75             self.logger.save()
     76 

/code/pytorch-lightning/pytorch_lightning/loggers/base.py in agg_and_log_metrics(self, metrics, step)
    131         """
    132         print("metrics", metrics.keys())
--> 133         agg_step, metrics_to_log = self._aggregate_metrics(metrics=metrics, step=step)
    134 
    135         if metrics_to_log:

/code/pytorch-lightning/pytorch_lightning/loggers/base.py in _aggregate_metrics(self, metrics, step)
     88 
     89         # compute the metrics
---> 90         agg_step, agg_mets = self._reduce_agg_metrics()
     91 
     92         # as new step received reset accumulator

/code/pytorch-lightning/pytorch_lightning/loggers/base.py in _reduce_agg_metrics(self)
    109             print("self._agg_key_funcs", self._agg_key_funcs)
    110             print("self._agg_default_func", self._agg_default_func)
--> 111             agg_mets = merge_dicts(self._metrics_to_agg, self._agg_key_funcs, self._agg_default_func)
    112         return self._prev_step, agg_mets
    113 

/code/pytorch-lightning/pytorch_lightning/loggers/base.py in merge_dicts(dicts, agg_key_funcs, default_func)
    386             print("fn or default_func", fn or default_func)
    387             print("values_to_agg", values_to_agg)
--> 388             d_out[k] = (fn or default_func)(values_to_agg)
    389 
    390     return d_out

<__array_function__ internals> in mean(*args, **kwargs)

/venvs/dev3.7/lib/python3.7/site-packages/numpy/core/fromnumeric.py in mean(a, axis, dtype, out, keepdims)
   3333 
   3334     return _methods._mean(a, axis=axis, dtype=dtype,
-> 3335                           out=out, **kwargs)
   3336 
   3337 

/venvs/dev3.7/lib/python3.7/site-packages/numpy/core/_methods.py in _mean(a, axis, dtype, out, keepdims)
    149             is_float16_result = True
    150 
--> 151     ret = umr_sum(arr, axis, dtype, out, keepdims)
    152     if isinstance(ret, mu.ndarray):
    153         ret = um.true_divide(

TypeError: unsupported operand type(s) for +: 'dict' and 'dict'

To Reproduce

Steps to reproduce the behavior:

  1. Add a training_epoch_end function to your LightningModule and run it (you can use mine from the code sample section below). The key is that the 'log' section of the training_epoch_end output must have a different format than the dictionary returned by your training step; in the example I provided, {'train_loss': tensor(0.3616, device='cuda:0'), 'epoch': 0} and {'loss': 0.48756, 'input': ..., 'ouput': ...} are different formats because they don't share the same keys.

Code sample

    def training_epoch_end(self, outputs):
        print("training_epoch_end")
        print(len(outputs))

        prefix = "train_"
        metric_modules = self.training_metric_modules

        # Handle loss:
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        result = {'log': {}, 'progress_bar': {}}
        result[prefix + 'loss'] = avg_loss
        result['log'][prefix + 'loss'] = avg_loss
        result['progress_bar'][prefix + 'loss'] = avg_loss

        # Add tracking variables
        result['log']['epoch'] = self.current_epoch

        return result

Expected behavior

I expected training_epoch_end to run without its log output being combined with the training batches of the next epoch, and to finish without an error, just like validation_epoch_end.

Environment

  • CUDA:
    - GPU:
    - Tesla V100-SXM2-16GB
    - available: True
    - version: 10.2
  • Packages:
    - numpy: 1.18.5
    - pyTorch_debug: False
    - pyTorch_version: 1.5.0
    - pytorch-lightning: 0.8.4
    - tensorboard: 2.2.2
    - tqdm: 4.46.1
  • System:
    - OS: Linux
    - architecture: 64bit
    - processor: x86_64
    - python: 3.7.7
    - version: #113-Ubuntu SMP Wed Jan 29 14:54:54 UTC 2020
ameliatqy added the bug and help wanted labels on Jul 1, 2020
github-actions bot commented Jul 1, 2020

Hi! thanks for your contribution!, great first issue!

ameliatqy (Author) commented:

An update on the situation: I think I found the cause of the error. It seems that self.global_step wasn't incrementing after processing training_epoch_end. The counter increments for validation_epoch_end:
validation_epoch_end -> step 547

And then for training_epoch_end:
training_epoch_end -> step 548

But it does not increment as we progress to the next epoch:
First batch of next epoch -> step 548

This triggers the following piece of code in _aggregate_metrics, putting the log output of training_epoch_end and the first batch of the next epoch together, as described in the initial problem statement:

        if step == self._prev_step:
            self._metrics_to_agg.append(metrics)
            return step, None
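
For illustration only, here is a minimal sketch of that buffer-by-step behaviour (simplified, hypothetical names, not Lightning's actual class): metrics reported with the same step are buffered together and only reduced once a newer step arrives, which is why two different events that report the same global_step end up in one aggregation batch.

    # Hypothetical, simplified sketch of the aggregator's buffering logic.
    class StepBuffer:
        def __init__(self):
            self._metrics_to_agg = []
            self._prev_step = None

        def add(self, metrics, step):
            if self._prev_step is None or step == self._prev_step:
                # Same step as before: keep buffering, nothing is logged yet.
                self._metrics_to_agg.append(metrics)
                self._prev_step = step
                return None
            # New step: flush everything buffered for the previous step.
            buffered, self._metrics_to_agg = self._metrics_to_agg, [metrics]
            self._prev_step = step
            return buffered  # the caller reduces these, e.g. with a per-key mean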

My suggested solution is to add self.global_step += 1 to run_training_epoch_end, like so:

    def run_training_epoch_end(self, epoch_output):
        model = self.get_model()
        if self.is_overridden('training_epoch_end', model=model):
            .......
            self.global_step += 1

Let me know if you foresee any issues or problems with this solution. If you would like me to submit a pull request, I would be happy to do so.

williamFalcon (Contributor) commented:

@ameliatqy please submit a PR. good catch!

williamFalcon added the priority: 0 label on Jul 2, 2020
williamFalcon added commits that referenced this issue on Jul 2, 2020
awaelchli (Contributor) commented Jul 2, 2020

I don't think this is correct. Now it looks like we have a step update on the last batch and additionally in the epoch-end call, meaning that after n epochs, global_step > n * num_batches_per_epoch (for example, 10 epochs of 10 batches each would end at step 110 instead of 100).
I think the real fix should be to combine the logging of the last batch and the training epoch end into one global step update.
Otherwise you will get misaligned logs.
@PyTorchLightning/core-contributors

williamFalcon added commits that referenced this issue on Jul 2, 2020
awaelchli (Contributor) commented Jul 2, 2020

[Plot: global_step progression before and after the change]

I ran this experiment before and after the change. The setup:

  • 10 epochs
  • each epoch has 10 train batches, no val batches
  • training step adds a train_loss
  • training epoch end adds an epoch_loss

Expected: all plots end at step 100 = 10 epochs * 10 batches = global step.

Neither the version before nor the current one works as expected.
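
A minimal sketch of this kind of experiment (a hypothetical toy module with random data, not the exact script used here, written against the 0.8.x dict-returning API):

    import torch
    import pytorch_lightning as pl
    from torch.utils.data import DataLoader, TensorDataset

    class ToyModule(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(4, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = torch.nn.functional.mse_loss(self.layer(x), y)
            # Logged every step -> the global step should advance 10 times per epoch.
            return {'loss': loss, 'log': {'train_loss': loss}}

        def training_epoch_end(self, outputs):
            avg = torch.stack([o['loss'] for o in outputs]).mean()
            # Logged once per epoch -> should not add an extra global step.
            return {'log': {'epoch_loss': avg}}

        def train_dataloader(self):
            # 80 samples / batch size 8 = 10 train batches per epoch, no val loader.
            ds = TensorDataset(torch.randn(80, 4), torch.randn(80, 1))
            return DataLoader(ds, batch_size=8)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.01)

    # After 10 epochs, every logged curve should end at global step 100 = 10 * 10.
    trainer = pl.Trainer(max_epochs=10, row_log_interval=1)
    trainer.fit(ToyModule())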

williamFalcon added commits that referenced this issue on Jul 2, 2020
ameliatqy (Author) commented:

(Quoting awaelchli's experiment above: "Neither the version before nor the current one works as expected.")

That's a good point. I looked into the problem with the previous version and encountered another error: if I set row_log_interval = 1 in the Trainer (thereby forcing every step to log), I hit the same error as described in my first post. This time, instead of the training_epoch_end metrics and the first batch, it tries to combine the validation metrics with the metrics from the last batch.

I will continue working on a solution for this.

ameliatqy (Author) commented:

Okay, here is my second shot at a solution for this issue. To keep the steps consistent, as raised by @awaelchli, I decided to combine the metrics for the last batch, the training_epoch_end metrics, and the validation_epoch_end metrics. So I did two things:

  1. I changed how the steps are incremented in run_training_epoch. The step increments in the batch loop as usual until we reach the last batch, for which the global_step is instead incremented at the very end of the function. To accomplish this, I added an if statement that prevents the increment on the last batch:
 if not is_last_batch:
      self.increment_accumulated_grad_global_step()

And a line of code to increment the global step by one at the very end of run_training_epoch:

self.global_step += 1
  2. I also changed how metrics from the same step are combined. In the case that the metric keys are different (as with the training batch metrics, the training_epoch_end metrics, and the validation_epoch_end metrics), we simply put all the keys together in one dictionary. So, as an example, we collect our metrics from the same step in self._metrics_to_agg:
self._metrics_to_agg = [{"loss": ...}, {"train_epoch_end_loss": ...}, {"val_epoch_end_loss": ...}]

Which _reduce_agg_metrics will transform into:

agg_mets = {"loss": ..., "train_epoch_end_loss": ..., "val_epoch_end_loss": ...}

To do this, I edited _reduce_agg_metrics into the following function:

    def _reduce_agg_metrics(self):
        ...
        else:
            # pop out 'epoch' because it is a metric automatically added in by log_metrics and will count as a
            # duplicate. If you want to get rid of this, I would suggest you should get rid of `scalar_metrics[
            # 'epoch'] = self.current_epoch` in TrainerLoggingMixin.log_metrics()
            for i in range(len(self._metrics_to_agg)):
                self._metrics_to_agg[i].pop("epoch")

            # check if dictionary keys are unique
            agg_keys = set([key for mets in self._metrics_to_agg for key in mets.keys()])
            num_keys = sum([len(mets) for mets in self._metrics_to_agg])

            if len(agg_keys) == num_keys:
                # if dictionary keys are unique
                agg_mets = self._metrics_to_agg[0]
                for mets in self._metrics_to_agg[1:]:
                    agg_mets.update(mets)
            else:
                agg_mets = merge_dicts(self._metrics_to_agg, self._agg_key_funcs, self._agg_default_func)

        return self._prev_step, agg_mets

For this solution to work, you have to make sure that the metric keys for the last batch, the training_epoch_end metrics, and the validation_epoch_end metrics are all different and unique (which I believe is already the standard case?).

You can see the steps are fixed; I ran 10 batches for 5 epochs.

[Screenshots: logged metric plots with curves ending at the expected global step]

I decided to submit a draft PR (#2475) because it is easier to view the changes that way, as I made changes in multiple areas. Let me know if you see any issues with it.

ameliatqy (Author) commented Jul 3, 2020

My bad, I can't rebase. Please see #2478 for the new pull request (and disregard #2475).

williamFalcon added a commit that referenced this issue Jul 3, 2020
ameliatqy mentioned this issue on Aug 15, 2020