Update PyTorch 2.4 tests #20079

Merged 4 commits on Jul 13, 2024
src/lightning/fabric/strategies/model_parallel.py (2 changes: 1 addition & 1 deletion)

@@ -458,7 +458,7 @@ def _load_checkpoint(
             return metadata

         if _is_full_checkpoint(path):
-            checkpoint = torch.load(path, mmap=True, map_location="cpu")
+            checkpoint = torch.load(path, mmap=True, map_location="cpu", weights_only=False)
             _load_raw_module_state(checkpoint.pop(module_key), module, strict=strict)

             state_dict_options = StateDictOptions(

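Why `weights_only=False` here: a full checkpoint written by the strategy mixes tensors with arbitrary pickled Python objects (optimizer state, metadata), which PyTorch 2.4's restricted `weights_only` loader can refuse, and 2.4 also emits a FutureWarning whenever the argument is left unset. A minimal sketch of the distinction, assuming PyTorch >= 2.4; the `RunMetadata` class and file name are illustrative, not part of this PR:

import torch

class RunMetadata:
    def __init__(self, step: int):
        self.step = step

# A "full" checkpoint mixes tensors with arbitrary Python objects.
ckpt = {"model": torch.nn.Linear(2, 2).state_dict(), "meta": RunMetadata(step=10)}
torch.save(ckpt, "full.ckpt")

# weights_only=True would use a restricted unpickler and reject RunMetadata;
# weights_only=False keeps the legacy behaviour and avoids the 2.4 FutureWarning
# that fires when the argument is not passed at all.
loaded = torch.load("full.ckpt", map_location="cpu", weights_only=False)
assert loaded["meta"].step == 10
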
tests/tests_fabric/plugins/precision/test_bitsandbytes.py (4 changes: 2 additions & 2 deletions)

@@ -225,7 +225,7 @@ def __init__(self):
     model = MyModel()
     ckpt_path = tmp_path / "foo.ckpt"
     torch.save(state_dict, ckpt_path)
-    torch.load(str(ckpt_path), mmap=True)
+    torch.load(str(ckpt_path), mmap=True, weights_only=True)
     keys = model.load_state_dict(state_dict, strict=True, assign=True)  # quantizes
     assert not keys.missing_keys
     assert model.l.weight.device.type == "cuda"

@@ -258,7 +258,7 @@ def forward(self, x):
     fabric = Fabric(accelerator="cuda", devices=1, plugins=BitsandbytesPrecision("nf4-dq"))
     model = Model()
     model = fabric.setup(model)
-    state_dict = torch.load(tmp_path / "checkpoint.pt")
+    state_dict = torch.load(tmp_path / "checkpoint.pt", weights_only=True)
     model.load_state_dict(state_dict)
     assert model.linear.weight.dtype == torch.uint8
     assert model.linear.weight.shape == (128, 1)

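The two test call sites above go the other way: the files they read are plain state dicts containing only tensors, so `weights_only=True` (the stricter, safe loader) works and keeps the suite free of the new warning. A small sketch of that pattern, assuming PyTorch >= 2.4; the file name is illustrative:

import torch

# A plain state_dict holds only tensors, so the restricted loader accepts it.
state_dict = torch.nn.Linear(4, 4).state_dict()
torch.save(state_dict, "weights.ckpt")

# weights_only=True opts into the safe unpickler; mmap=True maps the file into
# memory lazily instead of reading it all up front, matching the test above.
loaded = torch.load("weights.ckpt", mmap=True, weights_only=True)
assert set(loaded) == {"weight", "bias"}
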
(file header not captured: path and change counts missing)

@@ -301,6 +301,7 @@ def test_train_save_load(precision, tmp_path):
     assert state["coconut"] == 11


+@pytest.mark.filterwarnings("ignore::FutureWarning")
 @RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
 def test_save_full_state_dict(tmp_path):
     """Test that ModelParallelStrategy saves the full state into a single file with

@@ -401,6 +402,7 @@ def test_save_full_state_dict(tmp_path):
     _train(fabric, model, optimizer)


+@pytest.mark.filterwarnings("ignore::FutureWarning")
 @RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
 def test_load_full_state_dict_into_sharded_model(tmp_path):
     """Test that the strategy can load a full-state checkpoint into a distributed model."""

tests/tests_pytorch/strategies/test_ddp_integration.py (2 changes: 2 additions & 0 deletions)

@@ -282,6 +282,8 @@ def configure_optimizers(self):
         return ZeroRedundancyOptimizer(self.layer.parameters(), optimizer_class=torch.optim.Adam, lr=0.1)


+# ZeroRedundancyOptimizer internally calls `torch.load` with `weights_only` not set, triggering the FutureWarning
+@pytest.mark.filterwarnings("ignore::FutureWarning")
 @RunIf(min_cuda_gpus=2, skip_windows=True)
 @pytest.mark.parametrize("strategy", [pytest.param("ddp", marks=RunIf(standalone=True)), "ddp_spawn"])
 def test_ddp_strategy_checkpoint_zero_redundancy_optimizer(strategy, tmp_path):

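The `filterwarnings` marks added throughout this PR all rely on the same pytest mechanism: the filter is installed only for the decorated test, so a FutureWarning raised inside third-party code (here the internal `torch.load` call in `ZeroRedundancyOptimizer`) is ignored rather than escalated by suites that turn warnings into errors. A self-contained sketch, with the warning raised directly for illustration:

import warnings

import pytest

@pytest.mark.filterwarnings("ignore::FutureWarning")
def test_ignores_future_warning():
    # Stand-in for a third-party call that warns; the mark scopes the ignore
    # filter to this test only, so other tests still surface FutureWarnings.
    warnings.warn("torch.load without weights_only set", FutureWarning)
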
(file header not captured: path and change counts missing)

@@ -418,6 +418,7 @@ def test_load_full_state_checkpoint_into_regular_model(tmp_path):
     trainer.strategy.barrier()


+@pytest.mark.filterwarnings("ignore::FutureWarning")
 @RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True)
 def test_load_standard_checkpoint_into_distributed_model(tmp_path):
     """Test that a regular checkpoint (weights and optimizer states) can be loaded into a distributed model."""

@@ -458,6 +459,7 @@ def test_load_standard_checkpoint_into_distributed_model(tmp_path):
     trainer.strategy.barrier()


+@pytest.mark.filterwarnings("ignore::FutureWarning")
 @RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
 def test_save_load_sharded_state_dict(tmp_path):
     """Test saving and loading with the distributed state dict format."""