diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py index 6eafc19d863ee..f75b0a1f1a582 100644 --- a/tests/models/test_gpu.py +++ b/tests/models/test_gpu.py @@ -130,6 +130,7 @@ def assert_pred_same(): trainer.fit(model) +@pytest.mark.spawn @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") def test_multi_gpu_none_backend(tmpdir): """Make sure when using multiple GPUs the user can't use `distributed_backend = None`."""