diff --git a/tests/deepspeed/test_deepspeed.py b/tests/deepspeed/test_deepspeed.py index 2eb3a40b99d63e..7b50165babf49b 100644 --- a/tests/deepspeed/test_deepspeed.py +++ b/tests/deepspeed/test_deepspeed.py @@ -545,6 +545,7 @@ def test_stage3_nvme_offload(self): ds_config_zero3_dict = self.get_config_dict(ZERO3) ds_config_zero3_dict["zero_optimization"]["offload_optimizer"] = nvme_config ds_config_zero3_dict["zero_optimization"]["offload_param"] = nvme_config + ds_config_zero3_dict["zero_optimization"]["stage3_gather_16bit_weights_on_model_save"] = True trainer = get_regression_trainer(local_rank=0, fp16=True, deepspeed=ds_config_zero3_dict) with CaptureLogger(deepspeed_logger) as cl: trainer.train()