From 5947b39f54fa746e62222524e572153369883248 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 18 Feb 2021 14:53:09 +0100 Subject: [PATCH] ... --- pytorch_lightning/trainer/deprecated_api.py | 3 +++ pytorch_lightning/trainer/trainer.py | 3 ++- pytorch_lightning/utilities/data.py | 4 +--- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/deprecated_api.py b/pytorch_lightning/trainer/deprecated_api.py index 8e2e90dda8a7c..ddd54961c558c 100644 --- a/pytorch_lightning/trainer/deprecated_api.py +++ b/pytorch_lightning/trainer/deprecated_api.py @@ -133,6 +133,9 @@ def use_single_gpu(self, val: bool) -> None: if val: self.accelerator_connector._device_type = DeviceType.GPU + +class DeprecatedModelAttributes: + def get_model(self) -> LightningModule: rank_zero_warn( "The use of `Trainer.get_model()` is deprecated in favor of `Trainer.lightning_module`" diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index e0bc6d51dbb2b..7ad61020ab099 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -45,7 +45,7 @@ from pytorch_lightning.trainer.connectors.slurm_connector import SLURMConnector from pytorch_lightning.trainer.connectors.training_trick_connector import TrainingTricksConnector from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin -from pytorch_lightning.trainer.deprecated_api import DeprecatedDistDeviceAttributes +from pytorch_lightning.trainer.deprecated_api import DeprecatedDistDeviceAttributes, DeprecatedModelAttributes from pytorch_lightning.trainer.evaluation_loop import EvaluationLoop from pytorch_lightning.trainer.logging import TrainerLoggingMixin from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin @@ -80,6 +80,7 @@ class Trainer( TrainerTrainingTricksMixin, TrainerDataLoadingMixin, DeprecatedDistDeviceAttributes, + DeprecatedModelAttributes, ): @overwrite_by_env_vars diff --git a/pytorch_lightning/utilities/data.py b/pytorch_lightning/utilities/data.py index 6b887b8526f90..a73299e2af77b 100644 --- a/pytorch_lightning/utilities/data.py +++ b/pytorch_lightning/utilities/data.py @@ -30,9 +30,7 @@ def has_len(dataloader: DataLoader) -> bool: try: # try getting the length if len(dataloader) == 0: - raise ValueError( - '`Dataloader` returned 0 length. Please make sure that it returns at least 1 batch' - ) + raise ValueError('`Dataloader` returned 0 length. Please make sure that it returns at least 1 batch') has_len = True except TypeError: has_len = False