Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AMP for validation, prediction and testing #6565

Merged
merged 10 commits into from
Mar 20, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions pytorch_lightning/plugins/precision/native_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,21 @@ def train_step_context(self) -> Generator[None, None, None]:
"""Enable autocast context"""
with torch.cuda.amp.autocast():
yield

@contextmanager
def val_step_context(self) -> Generator[None, None, None]:
"""Enable autocast context"""
with torch.cuda.amp.autocast():
yield

@contextmanager
def test_step_context(self) -> Generator[None, None, None]:
"""Enable autocast context"""
with torch.cuda.amp.autocast():
yield

@contextmanager
def predict_context(self) -> Generator[None, None, None]:
"""Enable autocast context"""
with torch.cuda.amp.autocast():
yield
26 changes: 24 additions & 2 deletions tests/models/test_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,34 @@
from pytorch_lightning.plugins.environments import SLURMEnvironment
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel
from tests.helpers import BoringModel, RandomDataset
from tests.helpers.runif import RunIf


class AMPTestModel(BoringModel):

def training_step(self, batch, batch_idx):
def _step(self, batch, batch_idx):
assert torch.is_autocast_enabled()
output = self(batch)
assert output.dtype == torch.float16
loss = self.loss(batch, output)
return {"loss": loss}

def training_step(self, batch, batch_idx):
return self._step(batch, batch_idx)

def validation_step(self, batch, batch_idx):
self._step(batch, batch_idx)

def test_step(self, batch, batch_idx):
self._step(batch, batch_idx)

def predict(self, batch, batch_idx):
assert torch.is_autocast_enabled()
output = self(batch)
assert output.dtype == torch.float16
return output



@pytest.mark.skip(reason='dp + amp not supported currently') # TODO
Expand All @@ -54,6 +70,8 @@ def test_amp_single_gpu_dp(tmpdir):
model = AMPTestModel()
# tutils.run_model_test(trainer_options, model)
trainer.fit(model)
trainer.test(model)
trainer.predict(model, DataLoader(RandomDataset(32, 64)))

assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

Expand All @@ -73,6 +91,8 @@ def test_amp_single_gpu_ddp_spawn(tmpdir):
model = AMPTestModel()
# tutils.run_model_test(trainer_options, model)
trainer.fit(model)
trainer.test(model)
trainer.predict(model, DataLoader(RandomDataset(32, 64)))
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"


Expand Down Expand Up @@ -112,6 +132,8 @@ def test_amp_multi_gpu_ddp_spawn(tmpdir):
model = AMPTestModel()
# tutils.run_model_test(trainer_options, model)
trainer.fit(model)
trainer.test(model)
trainer.predict(model, DataLoader(RandomDataset(32, 64)))
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"


Expand Down