diff --git a/CHANGELOG.md b/CHANGELOG.md
index a1ae50dccc58a..b2b0d8feb528e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -46,6 +46,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed lr key name in case of param groups in LearningRateLogger ([#1719](https://github.com/PyTorchLightning/pytorch-lightning/pull/1719))
 
+- Fixed saving native AMP scaler state (introduced in [#1561](https://github.com/PyTorchLightning/pytorch-lightning/pull/1561))
+
+
 ## [0.7.5] - 2020-04-27
 
 ### Changed
diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py
index 4f474b761e94f..437a89e42470f 100644
--- a/pytorch_lightning/trainer/training_io.py
+++ b/pytorch_lightning/trainer/training_io.py
@@ -338,8 +338,8 @@ def dump_checkpoint(self):
 
         checkpoint['state_dict'] = model.state_dict()
 
-        # restore native amp scaling
-        if self.use_amp and self.use_native_amp and 'native_amp_scaling_state' in checkpoint:
+        # save native amp scaling
+        if self.use_amp and self.use_native_amp:
             checkpoint['native_amp_scaling_state'] = self.scaler.state_dict()
 
         if hasattr(model, "hparams"):
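
For context on the fix: `dump_checkpoint` builds the `checkpoint` dict from scratch, so the guard `'native_amp_scaling_state' in checkpoint` (copied from the restore path) was always false at save time and the scaler state was never written. Below is a minimal standalone sketch of the intended save/restore round trip with `torch.cuda.amp.GradScaler`; it is not Lightning's code path, only the key name `'native_amp_scaling_state'` is taken from the diff, and the file name `ckpt.pt` is illustrative.

```python
import torch

model = torch.nn.Linear(4, 2)
# GradScaler is a no-op when CUDA is unavailable; enable it conditionally
# so the sketch runs on CPU-only machines too.
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

# Saving: write the scaler's internal state (loss scale, growth tracker, ...)
# unconditionally whenever native AMP is in use -- the fixed branch above.
checkpoint = {
    'state_dict': model.state_dict(),
    'native_amp_scaling_state': scaler.state_dict(),
}
torch.save(checkpoint, 'ckpt.pt')

# Restoring: here a membership check *is* appropriate, since an older
# checkpoint may predate native AMP support and lack the key.
checkpoint = torch.load('ckpt.pt')
model.load_state_dict(checkpoint['state_dict'])
if checkpoint.get('native_amp_scaling_state'):
    scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
```

The asymmetry is the point of the patch: the `in checkpoint` check belongs only on the restore side, where the dict comes from disk; on the save side the dict is freshly constructed, so the condition reduces to "native AMP is enabled".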