Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add offset and length arguments for checkpoint validation functions #3246

Merged
merged 10 commits into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 16 additions & 18 deletions composer/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import warnings
from importlib import import_module
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from packaging import version
Expand Down Expand Up @@ -54,16 +54,13 @@
_TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME = f'__{dist.get_global_rank()}_0.distcp'


def _get_checkpoint_validation_function() -> Optional[Callable[[Union[Path, str]], bool]]:
"""Get the validation function by name.

Args:
name (str): Qualified name of the checkpoint validation function.
It should be in the form '{module_name}.{fn_name}'.
def _get_checkpoint_validation_function(
) -> Optional[Callable[[Union[Path, str], Optional[List[Tuple[int, int]]]], bool]]:
"""Get the validation function specified by the environment variable `CHECKPOINT_VALIDATION_FUNCTION`.

Returns:
Callable[[Union[Path, str]], bool] The checkpoint validation function that returns
True given a valid checkpoint and False otherwise.
Callable[[Union[Path, str], Optional[List[Tuple[int, int]]]], bool]: The checkpoint validation function. It takes a
checkpoint path and, optionally, a list of (offset, length) pairs to check, and returns True if the checkpoint is valid and False otherwise.
"""
name = os.environ.get('CHECKPOINT_VALIDATION_FUNCTION', None)
if name is None:
Expand All @@ -76,14 +73,16 @@ def _get_checkpoint_validation_function() -> Optional[Callable[[Union[Path, str]
return fn


def _ensure_valid_checkpoint(checkpoint_filepath: Union[Path, str]) -> Union[Path, str]:
def _ensure_valid_checkpoint(checkpoint_filepath: Union[Path, str],
specs: Optional[List[Tuple[int, int]]] = None) -> Union[Path, str]:
"""Ensures that the checkpoint at checkpoint_filepath is valid.

Validation uses the function specified by the CHECKPOINT_VALIDATION_FUNCTION environment variable.
If CHECKPOINT_VALIDATION_FUNCTION is not set, we skip validation.

Args:
checkpoint_filepath (Union[Path,str]): The path to the checkpoint file.
specs (Optional[List[Tuple[int, int]]]): A list of (offset, length) pairs to check. Defaults to None.

Raises:
ValueError if checkpoint file is invalid.
Expand All @@ -93,11 +92,10 @@ def _ensure_valid_checkpoint(checkpoint_filepath: Union[Path, str]) -> Union[Pat

# No function name has been specified.
if validate is None:
log.debug('No validation function specified. Skipping checkpoint validation.')
return checkpoint_filepath

# Validate the checkpoint.
if not validate(checkpoint_filepath):
if not validate(checkpoint_filepath, specs):
raise ValueError(f'Checkpoint at {checkpoint_filepath} is invalid.')

log.debug(f'Checkpoint at {checkpoint_filepath} is valid.')
Expand Down Expand Up @@ -169,13 +167,13 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner):
Raises:
ValueError if the data file is invalid.
"""
validated_checkpoint_paths = set()
path_to_specs: Dict[str, List[Tuple[int, int]]] = {}
for read_item in plan.items:
data_path = os.path.join(self.path, self.storage_data[read_item.storage_index].relative_path)
if data_path in validated_checkpoint_paths:
continue
_ensure_valid_checkpoint(data_path)
validated_checkpoint_paths.add(data_path)
item_md = self.storage_data[read_item.storage_index]
path = os.path.join(self.path, item_md.relative_path)
path_to_specs.setdefault(path, []).append((item_md.offset, item_md.length))
for path, spec in path_to_specs.items():
_ensure_valid_checkpoint(path, spec)
return super().read_data(plan, planner)

def read_metadata(self) -> Metadata:
Expand Down
18 changes: 16 additions & 2 deletions tests/trainer/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import tempfile
import time
from glob import glob
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union
from unittest.mock import MagicMock, patch

import pytest
Expand Down Expand Up @@ -1766,7 +1766,14 @@ def test_rotate_checkpoints(
dist.barrier() # all ranks finish before cleaning up tmpdir


def simple_validate(filepath: str):
def simple_validate(filepath: str, specs: Optional[List[Tuple[int, int]]] = None) -> bool:
    """Return whether the file at ``filepath`` holds the expected ``good`` marker.

    Args:
        filepath (str): Path to the file to check.
        specs (Optional[List[Tuple[int, int]]]): ``(offset, length)`` pairs to check. When given,
            every addressed range must equal ``good``. When ``None``, the entire file content
            must equal ``good``.

    Returns:
        bool: True if every checked range (or the whole file) matches, False otherwise.
    """
    expected = b'good'
    # Binary mode on purpose: seeking a text-mode stream to an arbitrary offset is undefined
    # behaviour, and text-mode ``read(length)`` counts characters rather than bytes.
    with open(filepath, 'rb') as f:
        if specs is None:
            return f.read() == expected
        for offset, length in specs:
            f.seek(offset)
            if f.read(length) != expected:
                return False
    return True

Expand Down Expand Up @@ -1795,6 +1802,13 @@ def test_checkpoint_validation(tmp_path):
result = _ensure_valid_checkpoint(checkpoint_filepath)
assert result == checkpoint_filepath

# Correct usage with offset and lengths and successful validation.
with open(checkpoint_filepath, 'w') as f:
f.write('good good')
with patch.dict(os.environ, {'CHECKPOINT_VALIDATION_FUNCTION': 'tests.trainer.test_checkpoint.simple_validate'}):
result = _ensure_valid_checkpoint(checkpoint_filepath, specs=[(0, 4), (5, 4)])
assert result == checkpoint_filepath

# Correct usage and failed validation.
with open(checkpoint_filepath, 'w') as f:
f.write('bad')
Expand Down
2 changes: 1 addition & 1 deletion tests/trainer/test_fsdp_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,7 @@ def test_checkpoint_loading_with_validation(world_size, tmp_path, is_valid_check
expectation = pytest.raises(ValueError)

def mock_get_checkpoint_validation_function():
return lambda _: is_valid_checkpoint
return lambda checkpoint_path, specs: is_valid_checkpoint

tmp_paths = dist.all_gather_object(os.path.abspath(tmp_path))
save_folder = os.path.join(tmp_paths[0], 'checkpoints')
Expand Down
Loading