Skip to content

Commit

Permalink
try spawn
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Jun 12, 2020
1 parent d786048 commit f672ab6
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions tests/metrics/test_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def tensor_test_metric(*args, **kwargs):
assert result.item() == 5.


def setup_ddp(rank, worldsize, ):
def _setup_ddp(rank, worldsize):
import os

os.environ['MASTER_ADDR'] = 'localhost'
Expand All @@ -114,8 +114,8 @@ def setup_ddp(rank, worldsize, ):
dist.init_process_group("gloo", rank=rank, world_size=worldsize)


def ddp_test_fn(rank, worldsize):
setup_ddp(rank, worldsize)
def _ddp_test_fn(rank, worldsize):
_setup_ddp(rank, worldsize)
tensor = torch.tensor([1.], device='cuda:0')

reduced_tensor = _sync_ddp_if_available(tensor)
Expand All @@ -124,16 +124,17 @@ def ddp_test_fn(rank, worldsize):
'Sync-Reduce does not work properly with DDP and Tensors'


@pytest.mark.spawn
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_sync_reduce_ddp():
"""Make sure sync-reduce works with DDP"""
tutils.reset_seed()
tutils.set_random_master_port()

worldsize = 2
mp.spawn(ddp_test_fn, args=(worldsize,), nprocs=worldsize)
mp.spawn(_ddp_test_fn, args=(worldsize,), nprocs=worldsize)

dist.destroy_process_group()
# dist.destroy_process_group()


def test_sync_reduce_simple():
Expand Down Expand Up @@ -168,7 +169,7 @@ def tensor_test_metric(*args, **kwargs):


def _ddp_test_tensor_metric(rank, worldsize):
setup_ddp(rank, worldsize)
_setup_ddp(rank, worldsize)
_test_tensor_metric(True)


Expand All @@ -179,8 +180,7 @@ def test_tensor_metric_ddp():

world_size = 2
mp.spawn(_ddp_test_tensor_metric, args=(world_size,), nprocs=world_size)

dist.destroy_process_group()
# dist.destroy_process_group()


def test_tensor_metric_simple():
Expand Down Expand Up @@ -209,17 +209,18 @@ def numpy_test_metric(*args, **kwargs):


def _ddp_test_numpy_metric(rank, worldsize):
setup_ddp(rank, worldsize)
_setup_ddp(rank, worldsize)
_test_numpy_metric(True)


@pytest.mark.spawn
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_numpy_metric_ddp():
tutils.reset_seed()
tutils.set_random_master_port()
world_size = 2
mp.spawn(_ddp_test_numpy_metric, args=(world_size,), nprocs=world_size)
dist.destroy_process_group()
# dist.destroy_process_group()


def test_numpy_metric_simple():
Expand Down

0 comments on commit f672ab6

Please sign in to comment.