Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
leoleoasd committed Oct 30, 2022
1 parent f7aa0c5 commit 482260d
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
from typing import Any, Dict, Optional

import torch
from fsspec.core import url_to_fs
from torch import Tensor
from torchmetrics import Metric
from fsspec.core import uri_to_fs

import pytorch_lightning as pl
from lightning_lite.plugins.environments.slurm import SLURMEnvironment
Expand Down Expand Up @@ -575,7 +575,7 @@ def __max_ckpt_version_in_folder(dir_path: _PATH, name_key: str = "ckpt_") -> Op
"""

# check directory existence
fs, uri = uri_to_fs(dir_path)
fs, uri = url_to_fs(dir_path)
if not fs.exists(dir_path):
return None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
import os
from unittest import mock

import fsspec.registry
import pytest
import torch
from fsspec.implementations.arrow import ArrowFSWrapper

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
Expand Down Expand Up @@ -102,6 +104,38 @@ def test_hpc_max_ckpt_version(tmpdir):
)


def test_max_ckpt_version_for_fsspec(tmpdir):
    """Test that the CheckpointConnector finds the highest hpc checkpoint version when the Trainer's root dir is
    given as a URL handled by a custom fsspec filesystem."""

    class MockFileSystem(ArrowFSWrapper):
        """Expose the filesystem that pyarrow resolves for the ``mock://`` URI through the fsspec interface."""

        protocol = "mock"

        def __init__(self, **kwargs):
            from pyarrow.fs import FileSystem

            # ``FileSystem.from_uri`` returns a ``(filesystem, path)`` tuple; only the filesystem
            # object may be handed to ``ArrowFSWrapper``.
            fs, _ = FileSystem.from_uri("mock://")
            super().__init__(fs=fs, **kwargs)

    # make ``mock://`` URLs resolvable by fsspec for the rest of the test
    fsspec.registry.register_implementation("mock", MockFileSystem)

    model = BoringModel()
    # ``tmpdir`` is a path object, not a ``str``: it has to be converted explicitly, otherwise
    # ``str + path`` raises ``TypeError``.
    trainer = Trainer(default_root_dir="mock://" + str(tmpdir), max_steps=1)
    trainer.fit(model)
    trainer.save_checkpoint(tmpdir / "hpc_ckpt.ckpt")
    trainer.save_checkpoint(tmpdir / "hpc_ckpt_0.ckpt")
    trainer.save_checkpoint(tmpdir / "hpc_ckpt_3.ckpt")
    trainer.save_checkpoint(tmpdir / "hpc_ckpt_33.ckpt")

    connector = trainer._checkpoint_connector
    assert connector._hpc_resume_path == str(tmpdir / "hpc_ckpt_33.ckpt")
    # the helper is name-mangled because it is declared with a double leading underscore
    assert connector._CheckpointConnector__max_ckpt_version_in_folder(tmpdir) == 33
    assert connector._CheckpointConnector__max_ckpt_version_in_folder(tmpdir / "not" / "existing") is None


def test_loops_restore(tmpdir):
"""Test that required loop state_dict is loaded correctly by checkpoint connector."""
model = BoringModel()
Expand Down

0 comments on commit 482260d

Please sign in to comment.