From c64578eb59e6772b339a6ee2bdac88c8d5982463 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 26 Mar 2021 17:41:35 +0100 Subject: [PATCH] Remove tests Discussed with SeanNaren --- benchmarks/test_sharded_parity.py | 23 +---------------------- 1 file changed, 1 insertion(+), 22 deletions(-) diff --git a/benchmarks/test_sharded_parity.py b/benchmarks/test_sharded_parity.py index 065b1531781d0..28cbd7828b108 100644 --- a/benchmarks/test_sharded_parity.py +++ b/benchmarks/test_sharded_parity.py @@ -122,7 +122,6 @@ def plugin_parity_test( model_cls: Type[SeedTrainLoaderModel], seed: int = 42, gpus: int = 0, - accelerator: str = 'ddp_spawn', precision: int = 32, max_percent_speed_diff: float = 0.1, ): @@ -134,7 +133,6 @@ def plugin_parity_test( model_cls: Model class to use for test. seed: Seed for generators. Note that this does not handle the seed for data-loading on multi-process. gpus: Number of GPUS to enable. - accelerator: Accelerator to use. precision: Whether to use AMP or normal FP32 training. max_percent_speed_diff: The maximum speed difference compared to normal DDP training. This is more a safety net for variability in CI which can vary in speed, not for benchmarking. @@ -151,7 +149,7 @@ def plugin_parity_test( max_epochs=1, gpus=gpus, precision=precision, - accelerator=accelerator, + accelerator='ddp_spawn', ) max_memory_ddp, ddp_time = record_ddp_fit_model_stats(trainer=trainer, model=ddp_model, use_cuda=use_cuda) @@ -225,22 +223,3 @@ def test_ddp_spawn_sharded_plugin(kwargs): # TODO: decrease speed diff since only 2 GPUs sharding 2 optimizers kwargs['max_percent_speed_diff'] = 0.25 plugin_parity_test(**kwargs) - - -@RunIf(min_gpus=2, fairscale=True, special=True) -def test_ddp_sharded_plugin(tmpdir): - plugin_parity_test( - gpus=2, - accelerator='ddp', - model_cls=SeedTrainLoaderModel, - ) - - -@RunIf(min_gpus=2, fairscale=True, special=True, amp_native=True) -def test_ddp_sharded_plugin_amp(tmpdir): - plugin_parity_test( - gpus=2, - accelerator='ddp', - precision=16, - model_cls=SeedTrainLoaderModel, - )