Attempt to add broken test to mimic transformers use case (#2272)
* Attempt to add broken test

* use wandb logger

* Update test_amp.py

Co-authored-by: William Falcon <waf2107@columbia.edu>
sshleifer and williamFalcon authored Jun 19, 2020
1 parent 54acc79 commit e780072
Showing 1 changed file with 27 additions and 0 deletions.
tests/models/test_amp.py: 27 additions, 0 deletions
@@ -54,6 +54,33 @@ def test_amp_multi_gpu(tmpdir, backend):
    assert result


@pytest.mark.spawn
@pytest.mark.parametrize("backend", ['dp', 'ddp'])
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_multi_gpu_wandb(tmpdir, backend):
    """Make sure DP/DDP + AMP work with a WandbLogger attached."""
    from pytorch_lightning.loggers import WandbLogger
    tutils.set_random_master_port()

    model = EvalModelTemplate()
    logger = WandbLogger(name='utest')

    trainer_options = dict(
        default_root_dir=tmpdir,
        max_epochs=1,
        gpus=2,
        distributed_backend=backend,
        precision=16,
        logger=logger,
    )
    # tutils.run_model_test(trainer_options, model)
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)
    assert result
    trainer.test(model)


@pytest.mark.spawn
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_amp_gpu_ddp_slurm_managed(tmpdir):
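
For anyone reproducing this locally, here is a minimal standalone sketch of the same Trainer + WandbLogger wiring the committed test exercises. It is not part of the commit: the EvalModelTemplate import path and the offline=True flag are assumptions for illustration (offline keeps the run from contacting the wandb servers; the committed test uses the default online logger and the helpers already imported in test_amp.py).

    # Hypothetical local repro of the committed test's setup (not part of this commit).
    # Assumes a 2-GPU machine with wandb installed.
    import torch
    from pytorch_lightning import Trainer
    from pytorch_lightning.loggers import WandbLogger
    from tests.base import EvalModelTemplate  # assumed import path for this PL version

    if torch.cuda.device_count() >= 2:
        model = EvalModelTemplate()
        logger = WandbLogger(name='utest', offline=True)  # offline=True is an assumption

        trainer = Trainer(
            default_root_dir='/tmp/amp_wandb_test',
            max_epochs=1,
            gpus=2,
            distributed_backend='ddp',
            precision=16,
            logger=logger,
        )
        assert trainer.fit(model)
        trainer.test(model)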

