Async callback: Don't skip checkpoints, reliably only launch async eval when the checkpoint is ready #813

Merged · 29 commits · Feb 16, 2024
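The gist of the change: instead of evaluating whatever checkpoint the CheckpointSaver last recorded, the callback now polls on a sub-interval, lists the remote save folder, and only launches an eval run for checkpoints whose upload is complete. A minimal sketch of that readiness idea, not code from this PR, using illustrative names (ready_checkpoint_groups, expected_files_per_checkpoint) that do not appear in the diff below:

# Sketch: only treat a checkpoint as evaluable once every expected file for it
# is visible in the remote save folder.
from collections import Counter
from pathlib import Path
from typing import List

def ready_checkpoint_groups(saved_checkpoints: List[str],
                            remote_files: List[str],
                            expected_files_per_checkpoint: int) -> List[str]:
    """Return checkpoint directory names whose expected file count is fully uploaded."""
    counts = Counter(Path(f).parts[-2] for f in remote_files)
    groups = {Path(c).parts[-2] for c in saved_checkpoints}
    return sorted(g for g in groups if counts[g] == expected_files_per_checkpoint)

# Example: for a sharded checkpoint written by 2 GPUs (2 shards + 1 metadata file),
# 'ep0-ba1000' is only returned once all 3 objects under save_folder/ep0-ba1000/ exist.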
306 changes: 240 additions & 66 deletions llmfoundry/callbacks/async_eval_callback.py
@@ -8,15 +8,18 @@

import logging
import os
import warnings
from collections import Counter
from pathlib import Path
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union

from composer.callbacks import CheckpointSaver
from composer.core import Callback, Event, State, Time, TimeUnit
from composer.core import Callback, Event, State, Time, Timestamp, TimeUnit
from composer.loggers import Logger
from composer.loggers.mosaicml_logger import (MOSAICML_PLATFORM_ENV_VAR,
RUN_NAME_ENV_VAR)
from composer.utils import dist
from composer.utils.file_helpers import list_remote_objects
from composer.utils.misc import create_interval_scheduler

from mcli import Run, RunConfig, create_run, get_run
@@ -73,34 +76,6 @@ def get_run_name(training_run_name: str, current_interval: str) -> str:
return f'{RUN_NAME_PREFIX}-{current_interval}-{name_without_uuid_suffix}'


def get_latest_checkpoint(event: Event, state: State) -> Optional[str]:
"""Get the latest checkpoint from the training run.

Args:
event: The current run event
state: The current state of the training run

Returns:
The path to the latest checkpoint, or None if there is not a latest checkpoint
"""
checkpointer = None
for callback in state.callbacks:
if isinstance(callback, CheckpointSaver):
checkpointer = callback
break

if not checkpointer:
log.warning('No checkpoint saver callback found')
return None

if not checkpointer.saved_checkpoints:
log.warning('No saved checkpoints found on the checkpointer')
return None

latest = checkpointer.saved_checkpoints[-1]
return str(Path(latest).parts[-1])


def get_eval_parameters(
parameters: Dict[str, Any],
checkpoint: str,
Expand Down Expand Up @@ -199,6 +174,9 @@ def validate_eval_run_config(
return run_config


CHECKS_PER_INTERVAL = 4


class AsyncEval(Callback):
"""Run the eval loop asynchronously as part of a MosaicML platform run.

Expand Down Expand Up @@ -234,76 +212,263 @@ def __init__(
eval_run_config: Optional[Dict[str, Any]] = None,
):

# Run these during init to fail fast in any of the error cases
for required in ('save_interval', 'save_folder'):
if required not in training_params:
raise ValueError(f'{required} required for async eval')

if '/' in training_params.get('save_filename', ''):
raise ValueError(
'AsyncEval not supported for save_filename that includes a path'
)

self.checkpoint_save_folder = training_params['save_folder']
self.training_params = training_params
self.eval_run_config = validate_eval_run_config(eval_run_config)
self.interval = validate_interval(interval,
self.training_params['save_interval'])
self.check_interval = create_interval_scheduler(
interval,
# There is a custom close to ensure that the final checkpoint
# (which is the most important) is evaled after it is written
include_end_of_training=False,
)
self.last_checkpoint: Optional[str] = None

# Run these during init to fail fast in any of the error cases
self.current_run = self._get_current_run()
get_eval_parameters(
parameters=training_params,
checkpoint='test',
training_run_name=self.current_run.name,
)
log.info(
f'Initialized AsyncEval callback. Will generate runs at interval {interval}'

# Validate the interval (how often to launch eval runs)
self.interval = validate_interval(interval,
self.training_params['save_interval'])

# Configures how often to check for new checkpoints. This is semi-arbitrary;
# really we just want to check often enough to pull relevant checkpoints
# but not so often that we're constantly checking
check_interval_value = max(self.interval.value // CHECKS_PER_INTERVAL,
1)
self.check_interval = Time(check_interval_value, self.interval.unit)
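# Worked example (assumed values): with interval '1000ba' and CHECKS_PER_INTERVAL = 4,
# check_interval_value = max(1000 // 4, 1) = 250, so checkpoint readiness is polled
# every 250 batches while eval runs are still only launched every 1000 batches.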

# Keep track of checkpoints that have already been evaled
# Format: {eval_timestamp: (checkpoint, run_name)}
self.checkpoints_evaled: Dict[Time, Tuple[str, str]] = {}

# Scheduling is based on the check interval, while _get_checkpoints_and_launch_runs
# will only launch runs at the interval
self.is_at_check_interval = create_interval_scheduler(
self.check_interval,
# There is a custom close to ensure that the final checkpoint
# (which is the most important) is evaled after it is written
include_end_of_training=False,
)

log.info('Initialized AsyncEval callback. Will generate runs at ' +
f'interval {interval}, checking at {self.check_interval}')

def state_dict(self) -> Dict[str, Any]:
checkpoints_evaled = []
for eval_ts, (checkpoint, run_name) in self.checkpoints_evaled.items():
eval_ts_dict = {
'value': eval_ts.value,
'unit': eval_ts.unit.value,
}
checkpoints_evaled.append((eval_ts_dict, checkpoint, run_name))

return {
'checkpoints_evaled': checkpoints_evaled,
}
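# Illustration (hypothetical entry): a checkpoint evaled at 1000 batches serializes as
#   ({'value': 1000, 'unit': 'ba'}, 's3://bucket/folder/ep0-ba1000-rank0.pt', 'eval-1000ba-my-run'),
# which load_state_dict below converts back into a Time key.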

def load_state_dict(self, state_dict: Dict[str, Any]):
previous_checkpoints_evaled = state_dict.get('checkpoints_evaled', [])
if previous_checkpoints_evaled:
for (eval_ts, checkpoint, run_name) in previous_checkpoints_evaled:
eval_ts = Time(eval_ts['value'], TimeUnit(eval_ts['unit']))
self.checkpoints_evaled[eval_ts] = (checkpoint, run_name)

log.info(
f'Loaded previous checkpoints evaled: {self.checkpoints_evaled}'
)

@staticmethod
def _get_ready_sharded_checkpoints(
checkpointer_checkpoints: Dict[str, Timestamp],
remote_files: List[str],
) -> Dict[str, Timestamp]:
"""Identify checkpoints ready to be evaled based on remote files.

This has special logic for sharded checkpoints, which are composed of
multiple shards (one per GPU) plus a metadata file.

Args:
checkpointer_checkpoints: All checkpoints from the checkpointer state
remote_files: List of remote files in the save folder

Returns:
Dict of checkpoints that are complete and ready to be evaled
"""
# Count the number of shards for each checkpoint group
remote_file_group_counts = Counter()
for f in remote_files:
checkpoint_ts_path = Path(f).parts[-2]
remote_file_group_counts[checkpoint_ts_path] += 1

# Check if all shards are present for each checkpoint group
checkpoints_to_eval = {}
for checkpoint, checkpoint_ts in checkpointer_checkpoints.items():
# eg {save_folder}/ep0-ba1/file.blah.
checkpoint_ts_path = Path(checkpoint).parts[-2]

# expecting one shard per gpu + 1 for metadata
expected_shard_count = dist.get_world_size() + 1
if remote_file_group_counts[
checkpoint_ts_path] != expected_shard_count:
log.debug(
f'Checkpoint {checkpoint} not fully uploaded (missing shards '
+
f'{remote_file_group_counts[checkpoint_ts_path]}/{expected_shard_count}), skipping'
)
continue

checkpoints_to_eval[checkpoint_ts_path] = checkpoint_ts

return checkpoints_to_eval
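# Illustration (assumed values): with 8 GPUs, a group like 'ep0-ba1000' is only
# returned once 9 remote files (8 shards + 1 metadata file) are visible under
# {save_folder}/ep0-ba1000/; fewer files means the upload is still in progress.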

@staticmethod
def _get_ready_single_checkpoints(
checkpointer_checkpoints: Dict[str, Timestamp],
remote_checkpoints: List[str],
) -> Dict[str, Timestamp]:
"""Identify checkpoints ready to be evaled based on remote checkpoints.

This is much simpler than the sharded case, because there is only one file

Args:
checkpointer_checkpoints: All checkpoints from the checkpointer state
remote_checkpoints: List of remote checkpoints in the save folder

Returns:
Dict of checkpoints that are complete and ready to be evaled
"""
unique_remote_checkpoints = set(remote_checkpoints)

checkpoints_to_eval = {}
for checkpoint, checkpoint_ts in checkpointer_checkpoints.items():
# This assumes checkpoint_ts_path is unique per checkpoint,
# eg the default {save_folder}/ep0-ba1-rank0.pt
checkpoint_ts_path = Path(checkpoint).parts[-1]

if checkpoint not in unique_remote_checkpoints:
log.debug(
f'Checkpoint {checkpoint} not fully uploaded, skipping')
continue

checkpoints_to_eval[checkpoint_ts_path] = checkpoint_ts
return checkpoints_to_eval

def _get_checkpoints_and_launch_runs(self, state: State):
"""Get the latest checkpoint from the training run.

Args:
state: The current state of the training run

Returns:
Returns checkpoints that have not been evaled
"""
checkpointer = None
for callback in state.callbacks:
if isinstance(callback, CheckpointSaver):
if checkpointer is None:
checkpointer = callback
else:
log.warning(
'Multiple checkpoint savers found. Using the first one')

if not checkpointer:
warnings.warn('No checkpoint saver callback found. Skipping eval')
return

if not checkpointer.all_saved_checkpoints_to_timestamp:
log.debug(
'No saved checkpoints found on the checkpointer. Skipping eval')
return

log.debug(
f'Found {len(checkpointer.all_saved_checkpoints_to_timestamp)} ' +
f'checkpoints: {checkpointer.all_saved_checkpoints_to_timestamp}')

remote_checkpoints = list_remote_objects(self.checkpoint_save_folder)

if not remote_checkpoints:
log.debug('No saved checkpoints found yet on remote. Skipping eval')
return

if state.fsdp_elastic_sharded_enabled:
checkpoints_to_eval = self._get_ready_sharded_checkpoints(
checkpointer.all_saved_checkpoints_to_timestamp,
remote_checkpoints)
else:
checkpoints_to_eval = self._get_ready_single_checkpoints(
checkpointer.all_saved_checkpoints_to_timestamp,
remote_checkpoints)

for checkpoint_interval_path, checkpoint_timestamp in checkpoints_to_eval.items(
):
checkpoint_ts = checkpoint_timestamp.get(self.interval.unit)
if checkpoint_ts.value % self.interval.value != 0:
log.debug(
f'Checkpoint {checkpoint_interval_path} ({checkpoint_ts}) is '
+ f'not at an eval interval ({self.interval}), skipping')
continue
if checkpoint_ts in self.checkpoints_evaled:
continue # Skip checkpoints that have already been evaled

full_checkpoint_path = f'{self.checkpoint_save_folder}/{checkpoint_interval_path}'
eval_run = self.launch_run(full_checkpoint_path, checkpoint_ts)
self.checkpoints_evaled[checkpoint_ts] = (
full_checkpoint_path,
eval_run.name,
)
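# Illustration (assumed values): with interval '1000ba', a checkpoint at 500ba is
# skipped (500 % 1000 != 0), while 1000ba and 2000ba each launch one eval run and are
# then recorded in checkpoints_evaled so later checks do not relaunch them.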

def run_event(self, event: Event, state: State, logger: Logger) -> None:
del logger

should_launch_run = all([
state.get_elapsed_duration() is not None,
self.check_interval(state, event),
# could also skip check intervals before the first async eval interval,
# but this may make the scheduler more complicated
self.is_at_check_interval(state, event),
dist.get_global_rank() == 0,
])

if should_launch_run:
current_interval = state.timestamp.get(self.interval.unit)
checkpoint = get_latest_checkpoint(event, state)
if not checkpoint:
return # warnings logged in get_latest_checkpoint

# TODO: ensure the checkpoint is fully written before launching the eval run
full_checkpoint = f'{self.checkpoint_save_folder}/{checkpoint}'
if full_checkpoint == self.last_checkpoint:
# Do not eval a checkpoint that has already been evaluated.
log.info(
'Skipping async eval because the checkpoint has not changed'
)
return

self.launch_run(full_checkpoint, current_interval)
self.last_checkpoint = full_checkpoint
self._get_checkpoints_and_launch_runs(state)

def close(self, state: State, logger: Logger) -> None:
del logger

if dist.get_global_rank() != 0:
return

save_latest_filename = self.training_params.get('save_latest_filename',
None)
# Eval any remaining checkpoints
self._get_checkpoints_and_launch_runs(state)

# Eval the latest checkpoint
latest_timestamp = state.timestamp.get(self.interval.unit)
if latest_timestamp not in self.checkpoints_evaled:
save_latest_filename = self.training_params.get(
'save_latest_filename', None)

if not save_latest_filename:
rank = dist.get_global_rank()
save_latest_filename = f'latest-rank{rank}.pt'

checkpoint = f'{self.checkpoint_save_folder}/{save_latest_filename}'

if not save_latest_filename:
rank = dist.get_global_rank()
save_latest_filename = f'latest-rank{rank}.pt'
eval_run = self.launch_run(checkpoint, latest_timestamp)
self.checkpoints_evaled[latest_timestamp] = (checkpoint,
eval_run.name)

checkpoint = f'{self.checkpoint_save_folder}/{save_latest_filename}'
self.launch_run(checkpoint, state.timestamp.get(self.interval.unit))
log.info(
f'AsyncEval callback finished. Launched {len(self.checkpoints_evaled)} eval runs:'
)
for checkpoint_ts, (checkpoint,
run_name) in self.checkpoints_evaled.items():
log.info(f' {checkpoint_ts}: {checkpoint}, {run_name}')

def _get_current_run(self) -> Run:
if os.environ.get(MOSAICML_PLATFORM_ENV_VAR,
@@ -322,6 +487,15 @@ def _get_current_run(self) -> Run:
return get_run(run_name, include_details=True)

def launch_run(self, checkpoint: str, current_interval: Time) -> Run:
"""Launch a new eval run.

Args:
checkpoint: The checkpoint to eval
current_interval: The interval of the checkpoint

Returns:
The launched run (mcli.Run type)
"""
log.info(f'Launching eval run for {checkpoint} at {current_interval}')

cfg = self.current_run.submitted_config