
Merge 326679f into b7c2e0a
Sean Naren authored Feb 17, 2021
2 parents b7c2e0a + 326679f commit b3fe91f
Showing 2 changed files with 8 additions and 8 deletions.
11 changes: 5 additions & 6 deletions pytorch_lightning/plugins/training_type/deepspeed.py
@@ -158,13 +158,12 @@ def _load_config(self, config):
             rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable")
             config = os.environ[self.DEEPSPEED_ENV_VAR]
         if isinstance(config, str) or isinstance(config, Path):
-            if os.path.exists(config):
-                with open(config) as f:
-                    config = json.load(f)
-            else:
+            if not os.path.isfile(config):
                 raise MisconfigurationException(
                     f"You passed in a path to a DeepSpeed config but the path does not exist: {config}"
                 )
+            with open(config) as f:
+                config = json.load(f)
         return config

     def pre_dispatch(self):
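The hunk above replaces the exists/else branching with a fail-fast check before opening the file; note that os.path.isfile is also stricter than os.path.exists, since a directory path now raises too. A minimal sketch of the resulting control flow (a hypothetical load_json_config helper, not the plugin method itself):

import json
import os
from pathlib import Path


def load_json_config(config):
    # Post-change flow: reject a path that is not an existing file up front,
    # then read the JSON without needing an else branch.
    if isinstance(config, (str, Path)):
        if not os.path.isfile(config):
            raise FileNotFoundError(f"DeepSpeed config path does not exist: {config}")
        with open(config) as f:
            config = json.load(f)
    return config


print(load_json_config({"zero_optimization": {"stage": 2}}))  # non-path configs pass through unchanged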
@@ -198,7 +197,7 @@ def _init_scheduler_optimizer(self):
         optimizers, schedulers, optimizer_frequencies = self.lightning_module.trainer.init_optimizers(
             self.lightning_module
         )
-        if (len(optimizers) != 1) or len(schedulers) > 1:
+        if len(optimizers) > 1 or len(schedulers) > 1:
             raise MisconfigurationException(
                 "DeepSpeed currently only supports single optimizer, single optional scheduler."
             )
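Loosening the guard from `!= 1` to `> 1` means an empty optimizer list no longer raises; only genuinely unsupported multi-optimizer or multi-scheduler setups are rejected. A standalone illustration of the relaxed check (hypothetical helper, not the plugin code):

def check_deepspeed_support(optimizers, schedulers):
    # Zero or one optimizer/scheduler is fine; more than one is rejected.
    if len(optimizers) > 1 or len(schedulers) > 1:
        raise ValueError("DeepSpeed currently only supports single optimizer, single optional scheduler.")


check_deepspeed_support([], [])            # passes now; the old `!= 1` check would have raised here
check_deepspeed_support(["adam"], ["lr"])  # passes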
@@ -234,7 +233,7 @@ def _initialize_deepspeed_inference(self, model):
         self.model_to_device()

         self.pre_configure_ddp()
-        self._model = DistributedDataParallel(
+        self.model = DistributedDataParallel(
             model,
             device_ids=self.determine_ddp_device_ids(),
             **self._ddp_kwargs,
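Assigning to `self.model` rather than the private `self._model` routes the wrapped module through the plugin's public model attribute instead of poking the internal field directly. A generic sketch of that setter pattern (assumed for illustration, not the actual plugin class):

class PluginSketch:
    def __init__(self):
        self._model = None

    @property
    def model(self):
        return self._model

    @model.setter
    def model(self, new_model):
        # `self.model = ...` goes through here, so any bookkeeping the setter
        # performs is not bypassed the way a raw `self._model = ...` would be.
        self._model = new_model


p = PluginSketch()
p.model = "wrapped module"  # uses the setter; p._model is updated as a result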
5 changes: 3 additions & 2 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -34,7 +34,6 @@
     DeepSpeedPrecisionPlugin,
     HorovodPlugin,
     NativeMixedPrecisionPlugin,
-    Plugin,
     PrecisionPlugin,
     ShardedNativeMixedPrecisionPlugin,
     SingleDevicePlugin,
@@ -147,7 +146,9 @@ def __init__(

         self.replace_sampler_ddp = replace_sampler_ddp

-    def handle_given_plugins(self, plugins: Optional[Union[Plugin, Sequence]]):
+    def handle_given_plugins(
+        self, plugins: Optional[Union[ClusterEnvironment, TrainingTypePlugin, PrecisionPlugin, Sequence]]
+    ):
         plugins = plugins if plugins is not None else []

         if isinstance(plugins, str):
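The base `Plugin` import is dropped in favour of spelling out the concrete kinds `handle_given_plugins` actually receives: a cluster environment, a training-type plugin, a precision plugin, a sequence of those, or (per the body) a string name. A hedged usage sketch of how such plugins typically reach this method via the Trainer (the DeepSpeedPlugin arguments and the "deepspeed" shorthand are illustrative assumptions, and running this requires a multi-GPU machine with deepspeed installed):

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin

# A single training-type plugin instance ...
trainer = Trainer(gpus=2, plugins=DeepSpeedPlugin(config="ds_config.json"))

# ... a list of plugin instances ...
trainer = Trainer(gpus=2, plugins=[DeepSpeedPlugin(config="ds_config.json")])

# ... or a string shorthand, which the isinstance(plugins, str) branch handles.
trainer = Trainer(gpus=2, plugins="deepspeed")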
