diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8064134b772789..ce55dccce35974 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -92,6 +92,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Added Autocast in validation, test and predict modes for Native AMP ([#6565](https://github.com/PyTorchLightning/pytorch-lightning/pull/6565))
+
 - Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011))
 
 
diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py
index 94e6cf376b03a5..d19b05358cdd58 100644
--- a/pytorch_lightning/plugins/precision/native_amp.py
+++ b/pytorch_lightning/plugins/precision/native_amp.py
@@ -93,3 +93,21 @@ def train_step_context(self) -> Generator[autocast, None, None]:
         """Enable autocast context"""
         with torch.cuda.amp.autocast():
             yield
+
+    @contextmanager
+    def val_step_context(self) -> Generator[None, None, None]:
+        """Enable autocast context for the validation step"""
+        with torch.cuda.amp.autocast():
+            yield
+
+    @contextmanager
+    def test_step_context(self) -> Generator[None, None, None]:
+        """Enable autocast context for the test step"""
+        with torch.cuda.amp.autocast():
+            yield
+
+    @contextmanager
+    def predict_context(self) -> Generator[None, None, None]:
+        """Enable autocast context for prediction"""
+        with torch.cuda.amp.autocast():
+            yield
diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py
index 53ec32764f3ed8..2f16f2fe64e75c 100644
--- a/tests/models/test_amp.py
+++ b/tests/models/test_amp.py
@@ -17,6 +17,7 @@
 import pytest
 import torch
 from torch import optim
+from torch.utils.data import DataLoader
 
 import tests.helpers.utils as tutils
 from pytorch_lightning import Trainer
@@ -24,17 +25,35 @@
 from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities import _APEX_AVAILABLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.helpers import BoringModel
+from tests.helpers import BoringModel, RandomDataset
 
 
 class AMPTestModel(BoringModel):
 
-    def training_step(self, batch, batch_idx):
+    def _step(self, batch, batch_idx):
         assert torch.is_autocast_enabled()
         output = self(batch)
         assert output.dtype == torch.float16
         loss = self.loss(batch, output)
-        return {"loss": loss}
+        return loss
+
+    def training_step(self, batch, batch_idx):
+        output = self._step(batch, batch_idx)
+        return {"loss": output}
+
+    def validation_step(self, batch, batch_idx):
+        output = self._step(batch, batch_idx)
+        return {"x": output}
+
+    def test_step(self, batch, batch_idx):
+        output = self._step(batch, batch_idx)
+        return {"y": output}
+
+    def predict(self, batch, batch_idx, dataloader_idx=None):
+        assert torch.is_autocast_enabled()
+        output = self(batch)
+        assert output.dtype == torch.float16
+        return output
 
 
 @pytest.mark.skip(reason='dp + amp not supported currently')  # TODO
@@ -54,6 +73,8 @@ def test_amp_single_gpu_dp(tmpdir):
     model = AMPTestModel()
     # tutils.run_model_test(trainer_options, model)
     trainer.fit(model)
+    trainer.test(model)
+    trainer.predict(model, DataLoader(RandomDataset(32, 64)))
 
     assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
 
@@ -73,6 +94,8 @@ def test_amp_single_gpu_ddp_spawn(tmpdir):
     model = AMPTestModel()
     # tutils.run_model_test(trainer_options, model)
     trainer.fit(model)
+    trainer.test(model)
+    trainer.predict(model, DataLoader(RandomDataset(32, 64)))
 
     assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
 
@@ -112,6 +135,8 @@ def test_amp_multi_gpu_ddp_spawn(tmpdir):
     model = AMPTestModel()
     # tutils.run_model_test(trainer_options, model)
     trainer.fit(model)
+    trainer.test(model)
+    trainer.predict(model, DataLoader(RandomDataset(32, 64)))
 
     assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
 