Fix usage of fs.listdir in CheckpointConnector (#15413)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: otaj <6065855+otaj@users.noreply.github.com>
(cherry picked from commit ee8a57d)
leoleoasd authored and lexierule committed Nov 10, 2022
1 parent 43c7778 commit c8ba921
Showing 3 changed files with 37 additions and 6 deletions.
src/pytorch_lightning/CHANGELOG.md (2 additions, 0 deletions)

@@ -50,6 +50,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fixed an attribute error in `ColossalAIStrategy` at import time when `torch.distributed` is not available ([#15535](https://github.com/Lightning-AI/lightning/pull/15535))
 
+- Fixed an issue when calling `fs.listdir` with a file URI instead of a path in `CheckpointConnector` ([#15413](https://github.com/Lightning-AI/lightning/pull/15413))
+
 - Fixed an issue with the `BaseFinetuning` callback not setting the `track_running_stats` attribute for batch normalization layers ([#15063](https://github.com/Lightning-AI/lightning/pull/15063))


src/pytorch_lightning/trainer/connectors/checkpoint_connector.py (11 additions, 6 deletions)

@@ -19,6 +19,8 @@
 from typing import Any, Dict, Optional
 
 import torch
+from fsspec.core import url_to_fs
+from fsspec.implementations.local import LocalFileSystem
 from torch import Tensor
 from torchmetrics import Metric

@@ -59,13 +61,16 @@ def __init__(self, trainer: "pl.Trainer", resume_from_checkpoint: Optional[_PATH
     @property
     def _hpc_resume_path(self) -> Optional[str]:
         dir_path_hpc = self.trainer.default_root_dir
-        fs = get_filesystem(dir_path_hpc)
-        if not fs.isdir(dir_path_hpc):
-            return None
+        dir_path_hpc = str(dir_path_hpc)
+        fs, path = url_to_fs(dir_path_hpc)
+        if not fs.isdir(path):
+            return None
         max_version = self.__max_ckpt_version_in_folder(dir_path_hpc, "hpc_ckpt_")
         if max_version is not None:
-            return os.path.join(dir_path_hpc, f"hpc_ckpt_{max_version}.ckpt")
+            if isinstance(fs, LocalFileSystem):
+                return os.path.join(dir_path_hpc, f"hpc_ckpt_{max_version}.ckpt")
+            else:
+                return dir_path_hpc + fs.sep + f"hpc_ckpt_{max_version}.ckpt"
 
     def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None:
         """Attempts to pre-load the checkpoint file to memory, with the source path determined in this priority:
@@ -574,12 +579,12 @@ def __max_ckpt_version_in_folder(dir_path: _PATH, name_key: str = "ckpt_") -> Op
"""

# check directory existence
fs = get_filesystem(dir_path)
fs, uri = url_to_fs(str(dir_path))
if not fs.exists(dir_path):
return None

# check corresponding file existence
files = [os.path.basename(f["name"]) for f in fs.listdir(dir_path)]
files = [os.path.basename(f["name"]) for f in fs.listdir(uri)]
files = [x for x in files if name_key in x]
if len(files) == 0:
return None
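For context, `fs.listdir` is fsspec's alias for `ls`, whose `detail=True` default returns one dict per entry with the full path in its "name" field, which is why the code reduces each name to a basename before matching `name_key`. A small self-contained sketch against the in-memory filesystem (file names invented for the example):

import os
from fsspec.core import url_to_fs

fs, uri = url_to_fs("memory://checkpoints")
fs.makedirs(uri, exist_ok=True)
fs.touch(uri + "/hpc_ckpt_7.ckpt")
fs.touch(uri + "/hpc_ckpt_12.ckpt")

# each listdir entry is a dict; "name" holds the full path
files = [os.path.basename(f["name"]) for f in fs.listdir(uri)]
versions = [int(f.replace("hpc_ckpt_", "").replace(".ckpt", "")) for f in files if "hpc_ckpt_" in f]
print(max(versions))  # 12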
tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py (24 additions, 0 deletions)

@@ -102,6 +102,30 @@ def test_hpc_max_ckpt_version(tmpdir):
     )
 
 
+def test_ckpt_for_fsspec():
+    """Test that the CheckpointConnector is able to write to fsspec file systems."""
+
+    model = BoringModel()
+    # hard-coding the dir since `tmpdir` can be a Windows path
+    trainer = Trainer(
+        default_root_dir="memory://test_ckpt_for_fsspec", limit_train_batches=1, limit_val_batches=1, max_epochs=1
+    )
+    trainer.fit(model)
+    trainer.save_checkpoint("memory://test_ckpt_for_fsspec/hpc_ckpt.ckpt")
+    trainer.save_checkpoint("memory://test_ckpt_for_fsspec/hpc_ckpt_0.ckpt")
+    trainer.save_checkpoint("memory://test_ckpt_for_fsspec/hpc_ckpt_3.ckpt")
+    trainer.save_checkpoint("memory://test_ckpt_for_fsspec/hpc_ckpt_33.ckpt")
+
+    assert trainer._checkpoint_connector._hpc_resume_path == "memory://test_ckpt_for_fsspec/hpc_ckpt_33.ckpt"
+    assert (
+        trainer._checkpoint_connector._CheckpointConnector__max_ckpt_version_in_folder("memory://test_ckpt_for_fsspec")
+        == 33
+    )
+    assert (
+        trainer._checkpoint_connector._CheckpointConnector__max_ckpt_version_in_folder("memory://not_existing") is None
+    )
+
+
 def test_loops_restore(tmpdir):
     """Test that required loop state_dict is loaded correctly by checkpoint connector."""
     model = BoringModel()
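The new test leans on fsspec's built-in `memory://` filesystem, which keeps all files in process memory, so it needs neither a tmpdir nor a network. A minimal round trip outside the Lightning test suite (paths invented for the example) would look like:

import fsspec
from fsspec.core import url_to_fs

# write a fake checkpoint into the in-memory filesystem
with fsspec.open("memory://demo/hpc_ckpt_33.ckpt", "wb") as f:
    f.write(b"checkpoint-bytes")

# resolve the URI the same way the fixed connector does, then list it
fs, path = url_to_fs("memory://demo")
assert any(e["name"].endswith("hpc_ckpt_33.ckpt") for e in fs.listdir(path))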
