From 326679f669fc9385e5f69a004bdca82266155c09 Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Wed, 17 Feb 2021 22:15:44 +0000
Subject: [PATCH] Address code review for deepspeed

---
 pytorch_lightning/plugins/training_type/deepspeed.py   | 11 +++++------
 .../trainer/connectors/accelerator_connector.py        |  5 +++--
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py
index 354ef5944ef42f..69fdb4c19a4b6e 100644
--- a/pytorch_lightning/plugins/training_type/deepspeed.py
+++ b/pytorch_lightning/plugins/training_type/deepspeed.py
@@ -158,13 +158,12 @@ def _load_config(self, config):
             rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable")
             config = os.environ[self.DEEPSPEED_ENV_VAR]
         if isinstance(config, str) or isinstance(config, Path):
-            if os.path.exists(config):
-                with open(config) as f:
-                    config = json.load(f)
-            else:
+            if not os.path.isfile(config):
                 raise MisconfigurationException(
                     f"You passed in a path to a DeepSpeed config but the path does not exist: {config}"
                 )
+            with open(config) as f:
+                config = json.load(f)
         return config
 
     def pre_dispatch(self):
@@ -198,7 +197,7 @@ def _init_scheduler_optimizer(self):
         optimizers, schedulers, optimizer_frequencies = self.lightning_module.trainer.init_optimizers(
             self.lightning_module
         )
-        if (len(optimizers) != 1) or len(schedulers) > 1:
+        if len(optimizers) > 1 or len(schedulers) > 1:
             raise MisconfigurationException(
                 "DeepSpeed currently only supports single optimizer, single optional scheduler."
             )
@@ -234,7 +233,7 @@ def _initialize_deepspeed_inference(self, model):
 
         self.model_to_device()
         self.pre_configure_ddp()
-        self._model = DistributedDataParallel(
+        self.model = DistributedDataParallel(
             model,
             device_ids=self.determine_ddp_device_ids(),
             **self._ddp_kwargs,
diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index 5549a29473fc7c..6bce1cd004642c 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -34,7 +34,6 @@
     DeepSpeedPrecisionPlugin,
     HorovodPlugin,
     NativeMixedPrecisionPlugin,
-    Plugin,
     PrecisionPlugin,
     ShardedNativeMixedPrecisionPlugin,
     SingleDevicePlugin,
@@ -147,7 +146,9 @@ def __init__(
 
         self.replace_sampler_ddp = replace_sampler_ddp
 
-    def handle_given_plugins(self, plugins: Optional[Union[Plugin, Sequence]]):
+    def handle_given_plugins(
+        self, plugins: Optional[Union[ClusterEnvironment, TrainingTypePlugin, PrecisionPlugin, Sequence]]
+    ):
         plugins = plugins if plugins is not None else []
 
         if isinstance(plugins, str):
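
Note (not part of the commit): the reworked guard in _load_config is equivalent to the
standalone sketch below. The local MisconfigurationException stub stands in for Lightning's
exception class so the snippet is self-contained, and the usage path is made up.

    import json
    import os
    from pathlib import Path


    class MisconfigurationException(Exception):
        """Stand-in for pytorch_lightning.utilities.exceptions.MisconfigurationException."""


    def load_config(config):
        # Fail fast if a str/Path was given but points at no file, then fall
        # through to a single happy-path load instead of an if/else split.
        if isinstance(config, (str, Path)):
            if not os.path.isfile(config):
                raise MisconfigurationException(
                    f"You passed in a path to a DeepSpeed config but the path does not exist: {config}"
                )
            with open(config) as f:
                config = json.load(f)
        return config


    # Hypothetical usage: returns the parsed dict, or raises before any I/O
    # when the path is wrong.
    # config = load_config("deepspeed_config.json")

Inverting the condition drops the else branch and keeps the existence check ahead of the file
read, which is why the first hunk nets one line shorter.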