Skip to content

Commit

Permalink
Add test
Browse files Browse the repository at this point in the history
  • Loading branch information
irenedea committed May 2, 2024
1 parent 2e074b7 commit 8c564e8
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
7 changes: 4 additions & 3 deletions composer/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@
_TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME = f'__{dist.get_global_rank()}_0.distcp'


def _get_checkpoint_validation_function() -> Optional[Callable[[Union[Path, str], Optional[List[Tuple[int,int]]]], bool]]:
def _get_checkpoint_validation_function(
) -> Optional[Callable[[Union[Path, str], Optional[List[Tuple[int, int]]]], bool]]:
"""Get the validation function by name.
Args:
Expand Down Expand Up @@ -99,9 +100,9 @@ def _ensure_valid_checkpoint(checkpoint_filepath: Union[Path, str],

# Validate the checkpoint.
if not validate(checkpoint_filepath, specs):
raise ValueError(f'Checkpoint at {checkpoint_filepath} is invalid.')
raise ValueError(f'Checkpoint at {checkpoint_filepath} {specs=}is invalid.')

log.debug(f'Checkpoint at {checkpoint_filepath} is valid.')
log.debug(f'Checkpoint at {checkpoint_filepath} {specs=} is valid.')
return checkpoint_filepath


Expand Down
18 changes: 16 additions & 2 deletions tests/trainer/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import tempfile
import time
from glob import glob
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union
from unittest.mock import MagicMock, patch

import pytest
Expand Down Expand Up @@ -1766,7 +1766,14 @@ def test_rotate_checkpoints(
dist.barrier() # all ranks finish before cleaning up tmpdir


def simple_validate(filepath: str):
def simple_validate(filepath: str, specs: Optional[List[Tuple[int, int]]] = None) -> bool:
if specs is not None:
with open(filepath, 'r') as f:
for offset, length in specs:
f.seek(offset)
if f.read(length) != 'good':
return False
return True
with open(filepath, 'r') as f:
return f.read() == 'good'

Expand Down Expand Up @@ -1795,6 +1802,13 @@ def test_checkpoint_validation(tmp_path):
result = _ensure_valid_checkpoint(checkpoint_filepath)
assert result == checkpoint_filepath

# Correct usage with offset and lengths and successful validation.
with open(checkpoint_filepath, 'w') as f:
f.write('good good')
with patch.dict(os.environ, {'CHECKPOINT_VALIDATION_FUNCTION': 'tests.trainer.test_checkpoint.simple_validate'}):
result = _ensure_valid_checkpoint(checkpoint_filepath, specs=[(0, 4), (5, 4)])
assert result == checkpoint_filepath

# Correct usage and failed validation.
with open(checkpoint_filepath, 'w') as f:
f.write('bad')
Expand Down

0 comments on commit 8c564e8

Please sign in to comment.