Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
four4fish committed Feb 17, 2022
1 parent 3152f81 commit 2c2e5ac
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 12 deletions.
12 changes: 3 additions & 9 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,7 @@ def _choose_accelerator(self) -> str:
return "cpu"

def _set_parallel_devices_and_init_accelerator(self) -> None:
# TODO add device availability check
self._parallel_devices: List[Union[int, torch.device]] = []

if isinstance(self._accelerator_flag, Accelerator):
Expand All @@ -451,8 +452,6 @@ def _set_parallel_devices_and_init_accelerator(self) -> None:
elif self._accelerator_flag == "gpu":
self.accelerator = GPUAccelerator()
self._set_devices_flag_if_auto_passed()
# TODO add device availablity check for all devices, not only GPU
self._check_device_availability()
if isinstance(self._devices_flag, int) or isinstance(self._devices_flag, str):
self._devices_flag = int(self._devices_flag)
self._parallel_devices = (
Expand Down Expand Up @@ -481,12 +480,6 @@ def _set_devices_flag_if_auto_passed(self) -> None:
if self._devices_flag == "auto" or not self._devices_flag:
self._devices_flag = self.accelerator.auto_device_count()

def _check_device_availability(self) -> None:
if not self.accelerator.is_available():
raise MisconfigurationException(
f"You requested {self._accelerator_flag}, " f"but {self._accelerator_flag} is not available"
)

def _choose_and_init_cluster_environment(self) -> ClusterEnvironment:
if isinstance(self._cluster_environment_flag, ClusterEnvironment):
return self._cluster_environment_flag
Expand Down Expand Up @@ -651,7 +644,8 @@ def _check_and_init_precision(self) -> PrecisionPlugin:
return NativeMixedPrecisionPlugin(self._precision_flag, device)

if self._amp_type_flag == AMPType.APEX:
return ApexMixedPrecisionPlugin(self._amp_level_flag) # type: ignore
self._amp_level_flag = self._amp_level_flag or "O2"
return ApexMixedPrecisionPlugin(self._amp_level_flag)

raise RuntimeError("No precision set")

Expand Down
5 changes: 3 additions & 2 deletions tests/accelerators/test_accelerator_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,8 +455,9 @@ def test_accelerator_cpu(mack_gpu_avalible):

with pytest.raises(MisconfigurationException, match="You requested gpu"):
trainer = Trainer(gpus=1)
with pytest.raises(MisconfigurationException, match="You requested gpu, but gpu is not available"):
trainer = Trainer(accelerator="gpu")
# TODO enable this test when add device availability check
# with pytest.raises(MisconfigurationException, match="You requested gpu, but gpu is not available"):
# trainer = Trainer(accelerator="gpu")
with pytest.raises(MisconfigurationException, match="You requested gpu:"):
trainer = Trainer(accelerator="cpu", gpus=1)

Expand Down
7 changes: 6 additions & 1 deletion tests/strategies/test_deepspeed_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,12 @@ def test_deepspeed_precision_choice(amp_backend, precision, tmpdir):
"""

trainer = Trainer(
fast_dev_run=True, default_root_dir=tmpdir, strategy="deepspeed", amp_backend=amp_backend, precision=precision
fast_dev_run=True,
default_root_dir=tmpdir,
accelerator="gpu",
strategy="deepspeed",
amp_backend=amp_backend,
precision=precision,
)

assert isinstance(trainer.strategy, DeepSpeedStrategy)
Expand Down

0 comments on commit 2c2e5ac

Please sign in to comment.