diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index aff19fb253f38..1f5b73f9be364 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1033,7 +1033,7 @@ def run_pretrain_routine(self, model: LightningModule):
             self.early_stop_callback._validate_condition_metric(callback_metrics)
 
         # clear cache before training
-        if self.on_gpu:
+        if self.on_gpu and self.root_gpu is not None:
             # use context because of:
             # https://discuss.pytorch.org/t/out-of-memory-when-i-use-torch-cuda-empty-cache/57898
             with torch.cuda.device(f'cuda:{self.root_gpu}'):
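
For context, the guard this hunk adds prevents the cache-clearing step from running when no root GPU is configured (e.g. CPU-only runs, where `self.root_gpu` is `None` and the f-string would produce the invalid device `'cuda:None'`). A minimal standalone sketch of the resulting pattern is below; `clear_cuda_cache` and its `root_gpu` parameter are illustrative names, not Lightning API:

```python
import torch

def clear_cuda_cache(root_gpu):
    """Free cached CUDA memory before training, but only when a root GPU is set."""
    # Guard on root_gpu: without it, a CPU-only run would build an
    # invalid device string like 'cuda:None' and crash.
    if torch.cuda.is_available() and root_gpu is not None:
        # empty_cache() acts on the *current* device, so select the root GPU
        # explicitly rather than implicitly creating a context on GPU 0.
        # See: https://discuss.pytorch.org/t/out-of-memory-when-i-use-torch-cuda-empty-cache/57898
        with torch.cuda.device(f'cuda:{root_gpu}'):
            torch.cuda.empty_cache()
```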