[Bug] mmrazor.engine.runner.quantization_loops.QATValLoop calls after_val_epoch hook twice with different keys, causing mmengine.hooks.checkpoint_hook._save_best_checkpoint() to fail with KeyError for the save_best config #637

Open
elisa-aleman opened this issue Apr 10, 2024 · 0 comments
Labels
bug Something isn't working

elisa-aleman commented Apr 10, 2024

Describe the bug

While running QAT on models with publicly available config files, such as RTMPose-tiny, I ran into this issue. The original config sets:

default_hooks.checkpoint.save_best='coco/AP'

This works normally in non-quantized training.

However, when the QAT config inherits this config as its _base_, mmrazor.engine.runner.quantization_loops.QATValLoop calls the after_val_epoch hook twice, each time with differently prefixed metric keys ('qat.' first, then 'original.'), as seen here:

    def run(self) -> dict:
        """Launch validation."""
        self.runner.call_hook('before_val')
        self.runner.call_hook('before_val_epoch')
        self.runner.model.eval()
        for idx, data_batch in enumerate(self.dataloader):
            self.run_iter(idx, data_batch, self.runner.model)

        # compute metrics
        metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
        qat_metrics = dict()
        for key, value in metrics.items():
            qat_key = 'qat.' + key
            ori_key = 'original.' + key
            qat_metrics[qat_key] = value
            self.runner.message_hub.log_scalars.pop(f'val/{ori_key}', None)

        self.runner.call_hook('after_val_epoch', metrics=qat_metrics)

        self.runner.call_hook('before_val_epoch')
        self.runner.model.eval()
        for idx, data_batch in enumerate(self.dataloader):
            self.run_iter(idx, data_batch, self.architecture)

        # compute metrics
        metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
        qat_metrics = dict()
        for key, value in metrics.items():
            qat_key = 'qat.' + key
            ori_key = 'original.' + key
            qat_metrics[ori_key] = value
            self.runner.message_hub.log_scalars.pop(f'val/{qat_key}', None)

        self.runner.call_hook('after_val_epoch', metrics=qat_metrics)

        self.runner.call_hook('after_val')
        return qat_metrics

This causes mmengine.hooks.checkpoint_hook._save_best_checkpoint() to fail with a KeyError unless save_best is overridden with 'auto'.

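Edit: 'auto' does not fix it either. It only sets key_indicator from the first metrics dict the hook receives (the 'qat.'-prefixed one), so the second call, whose keys carry the 'original.' prefix, still fails.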

In any case, some setups may need a save_best setting other than 'auto'. I suggest that the loop either try to save the best checkpoint only once per validation, or temporarily add the 'qat.' and 'original.' prefixes to the configured key_indicators each time the hook is called and remove them afterwards.
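The second suggestion (temporarily prefixing the key indicators) can be sketched with a context manager. This is a toy under stated assumptions: ToyHook and prefixed_key_indicators are hypothetical names, and ToyHook only mimics the part of CheckpointHook that reads metrics[key] for each key indicator; it is not mmengine's implementation.

```python
from contextlib import contextmanager


class ToyHook:
    """Hypothetical stand-in that reads metrics[key] for each key indicator."""

    def __init__(self, key_indicators):
        self.key_indicators = list(key_indicators)
        self.seen = []

    def after_val_epoch(self, metrics):
        for key in self.key_indicators:
            self.seen.append((key, metrics[key]))


@contextmanager
def prefixed_key_indicators(hook, prefix):
    """Temporarily prepend `prefix` to every key indicator, then restore them."""
    saved = hook.key_indicators
    hook.key_indicators = [prefix + key for key in saved]
    try:
        yield hook
    finally:
        hook.key_indicators = saved  # restored even if the hook raises


hook = ToyHook(['coco/AP'])
# Dummy scores, for illustration only.
calls = [('qat.', {'qat.coco/AP': 0.5}), ('original.', {'original.coco/AP': 0.6})]
for prefix, metrics in calls:
    with prefixed_key_indicators(hook, prefix):
        hook.after_val_epoch(metrics)

print(hook.seen)            # [('qat.coco/AP', 0.5), ('original.coco/AP', 0.6)]
print(hook.key_indicators)  # ['coco/AP']
```

A real CheckpointHook also keeps best-score state between calls, so a patch along these lines would still have to decide whether the QAT model's or the original model's score should select the best checkpoint; the sketch only covers the key handling.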

Reproduces the error - error message

Traceback (most recent call last):
...
  File ".../site-packages/mmrazor/engine/runner/quantization_loops.py" line 266, in run
    self.runner.call_hook('after_val_epoch', metrics=qat_metrics)
  File ".../site-packages/mmengine/runner/runner.py" line 1839, in call_hook
    getattr(hook, fn_name)(self, **kwargs)
  File ".../site-packages/mmengine/hooks/checkpoint_hook.py" line 361, in after_val_epoch
    self._save_best_checkpoint(runner, metrics)
  File ".../site-packages/mmengine/hooks/checkpoint_hook.py" line 505, in _save_best_checkpoint
    key_score = metrics[key_indicator]
                ~~~~~~~^^^^^^^^^^^^^^^
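To see why the lookup at the bottom of this traceback fails, here is a toy reproduction. ToyBestTracker is a hypothetical, heavily simplified stand-in for the save_best lookup in mmengine's CheckpointHook (assumptions: 'auto' resolves to the first key of the first metrics dict, and higher scores are better); it is not mmengine's code. qat_val_loop_metrics rebuilds the two key-prefixed dicts that QATValLoop passes to after_val_epoch.

```python
# ToyBestTracker is a hypothetical, simplified stand-in for the save_best
# lookup in mmengine's CheckpointHook; it is NOT mmengine's implementation.


class ToyBestTracker:
    def __init__(self, save_best):
        self.key_indicators = [save_best]
        self.best = {}

    def after_val_epoch(self, metrics):
        # Assumption: 'auto' is resolved once, from the first key of the
        # first metrics dict this hook receives.
        if self.key_indicators == ['auto']:
            self.key_indicators = [next(iter(metrics))]
        for key in self.key_indicators:
            score = metrics[key]  # KeyError when the key is missing
            # Assumption: higher is better (like rule='greater').
            self.best[key] = max(score, self.best.get(key, float('-inf')))


def qat_val_loop_metrics(qat_raw, original_raw):
    """Rebuild the two dicts QATValLoop passes to after_val_epoch."""
    qat = {'qat.' + key: value for key, value in qat_raw.items()}
    original = {'original.' + key: value for key, value in original_raw.items()}
    return [qat, original]


def simulate(save_best, qat_raw, original_raw):
    tracker = ToyBestTracker(save_best)
    try:
        for metrics in qat_val_loop_metrics(qat_raw, original_raw):
            tracker.after_val_epoch(metrics)
    except KeyError as err:
        return f'KeyError: {err}'
    return 'ok'


# Dummy scores, for illustration only.
qat_raw, original_raw = {'coco/AP': 0.5}, {'coco/AP': 0.6}
print(simulate('coco/AP', qat_raw, original_raw))      # KeyError: 'coco/AP'
print(simulate('auto', qat_raw, original_raw))         # KeyError: 'qat.coco/AP'
print(simulate('qat.coco/AP', qat_raw, original_raw))  # KeyError: 'qat.coco/AP'
```

Whatever single key is configured, one of the two calls is missing it: the unprefixed key fails on the first call, and any prefixed key (including the one 'auto' picks) fails on the other.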

Suggested fix

Edit: I changed the fix to use the QAT metrics, rather than the original architecture's metrics, to save the best checkpoint.

    @@ -255,12 +255,12 @@ class QATValLoop(ValLoop):
                 self.run_iter(idx, data_batch, self.runner.model)
     
             # compute metrics
    +        self.runner.logger.info('Evaluating QAT model')
             metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
             qat_metrics = dict()
             for key, value in metrics.items():
    -            qat_key = 'qat.' + key
                 ori_key = 'original.' + key
    -            qat_metrics[qat_key] = value
    +            qat_metrics[key] = value
                 self.runner.message_hub.log_scalars.pop(f'val/{ori_key}', None)
     
             self.runner.call_hook('after_val_epoch', metrics=qat_metrics)
    @@ -271,15 +271,10 @@ class QATValLoop(ValLoop):
                 self.run_iter(idx, data_batch, self.architecture)
     
             # compute metrics
    +        self.runner.logger.info('Evaluating original model architecture')
             metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
    -        qat_metrics = dict()
             for key, value in metrics.items():
    -            qat_key = 'qat.' + key
    -            ori_key = 'original.' + key
    -            qat_metrics[ori_key] = value
    -            self.runner.message_hub.log_scalars.pop(f'val/{qat_key}', None)
    -
    -        self.runner.call_hook('after_val_epoch', metrics=qat_metrics)
    +            self.runner.message_hub.log_scalars.pop(f'val/{key}', None)
     
             self.runner.call_hook('after_val')
             return qat_metrics

The only remaining issue is that the after_val_epoch hook is no longer called for the original architecture model, so some of the logs are incomplete.
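The control flow after the suggested patch, and the limitation it leaves, can be modeled with a small toy. RecordingHook and patched_val are hypothetical names, the two evaluators are plain callables returning dummy scores, and the message-hub bookkeeping from the diff is left out; this is a sketch of the flow, not mmrazor code.

```python
# Toy model of the patched QATValLoop flow. RecordingHook and patched_val are
# hypothetical; message-hub bookkeeping from the real patch is omitted.


class RecordingHook:
    """Records the save_best score seen on each after_val_epoch call."""

    def __init__(self, save_best):
        self.save_best = save_best
        self.scores = []

    def after_val_epoch(self, metrics):
        # Same kind of lookup that raised KeyError before the patch.
        self.scores.append(metrics[self.save_best])


def patched_val(hook, evaluate_qat, evaluate_original, log):
    log.append('Evaluating QAT model')
    qat_metrics = dict(evaluate_qat())   # keys are no longer prefixed
    hook.after_val_epoch(qat_metrics)    # the only after_val_epoch call
    log.append('Evaluating original model architecture')
    evaluate_original()                  # result never reaches a hook
    return qat_metrics


log = []
hook = RecordingHook('coco/AP')
# Dummy scores, for illustration only.
result = patched_val(hook, lambda: {'coco/AP': 0.5}, lambda: {'coco/AP': 0.6}, log)
print(result)       # {'coco/AP': 0.5}
print(hook.scores)  # [0.5]
```

save_best='coco/AP' now resolves, but the original model's score (0.6 here) is never passed to any hook, which is why those logs end up incomplete.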

@elisa-aleman elisa-aleman added the bug Something isn't working label Apr 10, 2024
@elisa-aleman elisa-aleman changed the title [Bug] (temporary fix) mmrazor.engine.runner.quantization_loops.QATValLoop calls after_val_epoch hook twice with different keys, causing mmengine.hooks.checkpoint_hook._savebest_checkpoint() to fail unless save_best is overwritten with 'auto' [Bug] (temporary fix) mmrazor.engine.runner.quantization_loops.QATValLoop calls after_val_epoch hook twice with different keys, causing mmengine.hooks.checkpoint_hook._save_best_checkpoint() to fail unless save_best is overwritten with 'auto' Apr 10, 2024
@elisa-aleman elisa-aleman changed the title [Bug] (temporary fix) mmrazor.engine.runner.quantization_loops.QATValLoop calls after_val_epoch hook twice with different keys, causing mmengine.hooks.checkpoint_hook._save_best_checkpoint() to fail unless save_best is overwritten with 'auto' [Bug] (temporary fix) mmrazor.engine.runner.quantization_loops.QATValLoop calls after_val_epoch hook twice with different keys, causing mmengine.hooks.checkpoint_hook._save_best_checkpoint() to fail with KeyError for the save_best config Apr 10, 2024
@elisa-aleman elisa-aleman changed the title [Bug] (temporary fix) mmrazor.engine.runner.quantization_loops.QATValLoop calls after_val_epoch hook twice with different keys, causing mmengine.hooks.checkpoint_hook._save_best_checkpoint() to fail with KeyError for the save_best config [Bug] mmrazor.engine.runner.quantization_loops.QATValLoop calls after_val_epoch hook twice with different keys, causing mmengine.hooks.checkpoint_hook._save_best_checkpoint() to fail with KeyError for the save_best config Apr 10, 2024