[integration] Update Ray Tune integration for Ray 2.7 (#26499)

* fix tune integration for ray 2.7+
* add version check for ray tune backend availability
* missing import
* pin min version instead
* address comments
* some fixes
* fix unnecessary final checkpoint
* fix lint
* dep table fix
* fix lint

Signed-off-by: Justin Yu <justinvyu@anyscale.com>
justinvyu committed Dec 9, 2023
1 parent ffd426e commit 5fa66df
Showing 5 changed files with 52 additions and 54 deletions.
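
The common thread across the diffs below is the move from the pre-2.7 `ray.tune.report` / `tune.checkpoint_dir` calls to the Ray 2.7+ `ray.train` session API, where metrics and an optional checkpoint are reported together. A minimal standalone sketch of that pattern, assuming a toy trainable outside of transformers (the function, metric, and file layout here are illustrative, not code from this commit):

import os
import tempfile

import ray.train
from ray import tune


def trainable(config: dict):
    # Resume from the checkpoint Tune restored for this trial, if any.
    start_step = 0
    checkpoint = ray.train.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            with open(os.path.join(checkpoint_dir, "step.txt")) as f:
                start_step = int(f.read())

    for step in range(start_step, 10):
        metrics = {"objective": float(step)}  # illustrative metric
        # Ray 2.7+: write checkpoint files into a temporary directory, wrap them
        # in a Checkpoint, and attach it to the metrics report. This replaces the
        # old tune.checkpoint_dir(...) context manager plus tune.report(...).
        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            with open(os.path.join(temp_checkpoint_dir, "step.txt"), "w") as f:
                f.write(str(step))
            checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
            ray.train.report(metrics, checkpoint=checkpoint)


tuner = tune.Tuner(trainable, param_space={})
results = tuner.fit()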
setup.py (2 changes: 1 addition & 1 deletion)
@@ -149,7 +149,7 @@
     "pytest-timeout",
     "pytest-xdist",
     "python>=3.8.0",
-    "ray[tune]",
+    "ray[tune]>=2.7.0",
     "regex!=2019.12.17",
     "requests",
     "rhoknp>=1.1.0,<1.3.1",
src/transformers/dependency_versions_table.py (2 changes: 1 addition & 1 deletion)
@@ -55,7 +55,7 @@
     "pytest-timeout": "pytest-timeout",
     "pytest-xdist": "pytest-xdist",
     "python": "python>=3.8.0",
-    "ray[tune]": "ray[tune]",
+    "ray[tune]": "ray[tune]>=2.7.0",
     "regex": "regex!=2019.12.17",
     "requests": "requests",
     "rhoknp": "rhoknp>=1.1.0,<1.3.1",
src/transformers/hyperparameter_search.py (4 changes: 2 additions & 2 deletions)
@@ -15,7 +15,7 @@
 
 from .integrations import (
     is_optuna_available,
-    is_ray_available,
+    is_ray_tune_available,
     is_sigopt_available,
     is_wandb_available,
     run_hp_search_optuna,
@@ -81,7 +81,7 @@ class RayTuneBackend(HyperParamSearchBackendBase):
 
     @staticmethod
     def is_available():
-        return is_ray_available()
+        return is_ray_tune_available()
 
     def run(self, trainer, n_trials: int, direction: str, **kwargs):
         return run_hp_search_ray(trainer, n_trials, direction, **kwargs)
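
With availability now keyed on `is_ray_tune_available()`, selecting the Ray backend from `Trainer.hyperparameter_search` is unchanged from the user's side. A hypothetical invocation for illustration (the search space and trial count are made up, and `trainer` is assumed to have been constructed with `model_init=...` so each trial gets a fresh model):

from ray import tune

best_run = trainer.hyperparameter_search(
    backend="ray",
    hp_space=lambda trial: {
        "learning_rate": tune.loguniform(1e-5, 1e-3),
        "per_device_train_batch_size": tune.choice([8, 16, 32]),
    },
    n_trials=4,
    direction="maximize",
)
print(best_run.run_id, best_run.objective, best_run.hyperparameters)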
src/transformers/integrations/integration_utils.py (54 changes: 26 additions & 28 deletions)
@@ -236,8 +236,9 @@ def _objective(trial, checkpoint_dir=None):
 
 def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
     import ray
+    import ray.train
 
-    def _objective(trial, local_trainer, checkpoint_dir=None):
+    def _objective(trial: dict, local_trainer):
         try:
             from transformers.utils.notebook import NotebookProgressCallback
 
@@ -246,19 +247,34 @@ def _objective(trial, local_trainer, checkpoint_dir=None):
         except ModuleNotFoundError:
             pass
 
-        checkpoint = None
-        if checkpoint_dir:
-            for subdir in os.listdir(checkpoint_dir):
-                if subdir.startswith(PREFIX_CHECKPOINT_DIR):
-                    checkpoint = os.path.join(checkpoint_dir, subdir)
         local_trainer.objective = None
-        local_trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
+
+        checkpoint = ray.train.get_checkpoint()
+        if checkpoint:
+            # Upon trial resume, the local_trainer's objective gets reset to None.
+            # If `local_trainer.train` is a noop (training has already reached
+            # the target number of epochs/steps), then this would
+            # trigger an unnecessary extra checkpoint at the end of training.
+            # -> Set the objective to a dummy value upon resume as a workaround.
+            local_trainer.objective = "objective"
+
+            with checkpoint.as_directory() as checkpoint_dir:
+                checkpoint_path = next(Path(checkpoint_dir).glob(f"{PREFIX_CHECKPOINT_DIR}*")).as_posix()
+                local_trainer.train(resume_from_checkpoint=checkpoint_path, trial=trial)
+        else:
+            local_trainer.train(trial=trial)
+
         # If there hasn't been any evaluation during the training loop.
         if getattr(local_trainer, "objective", None) is None:
             metrics = local_trainer.evaluate()
             local_trainer.objective = local_trainer.compute_objective(metrics)
-            local_trainer._tune_save_checkpoint()
-            ray.tune.report(objective=local_trainer.objective, **metrics, done=True)
+
+            metrics.update({"objective": local_trainer.objective, "done": True})
+
+            with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
+                local_trainer._tune_save_checkpoint(checkpoint_dir=temp_checkpoint_dir)
+                checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
+                ray.train.report(metrics, checkpoint=checkpoint)
 
     if not trainer._memory_tracker.skip_memory_metrics:
         from ..trainer_utils import TrainerMemoryTracker
@@ -296,28 +312,10 @@ def _objective(trial, local_trainer, checkpoint_dir=None):
         from ray.tune import CLIReporter
 
         kwargs["progress_reporter"] = CLIReporter(metric_columns=["objective"])
-    if "keep_checkpoints_num" in kwargs and kwargs["keep_checkpoints_num"] > 0:
-        # `keep_checkpoints_num=0` would disabled checkpointing
-        trainer.use_tune_checkpoints = True
-        if kwargs["keep_checkpoints_num"] > 1:
-            logger.warning(
-                f"Currently keeping {kwargs['keep_checkpoints_num']} checkpoints for each trial. "
-                "Checkpoints are usually huge, "
-                "consider setting `keep_checkpoints_num=1`."
-            )
+
     if "scheduler" in kwargs:
         from ray.tune.schedulers import ASHAScheduler, HyperBandForBOHB, MedianStoppingRule, PopulationBasedTraining
 
-        # Check if checkpointing is enabled for PopulationBasedTraining
-        if isinstance(kwargs["scheduler"], PopulationBasedTraining):
-            if not trainer.use_tune_checkpoints:
-                logger.warning(
-                    "You are using PopulationBasedTraining but you haven't enabled checkpointing. "
-                    "This means your trials will train from scratch everytime they are exploiting "
-                    "new configurations. Consider enabling checkpointing by passing "
-                    "`keep_checkpoints_num=1` as an additional argument to `Trainer.hyperparameter_search`."
-                )
-
         # Check for `do_eval` and `eval_during_training` for schedulers that require intermediate reporting.
         if isinstance(
             kwargs["scheduler"], (ASHAScheduler, MedianStoppingRule, HyperBandForBOHB, PopulationBasedTraining)
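
Note that the removed `keep_checkpoints_num` / `use_tune_checkpoints` plumbing has no in-function replacement: with the code above, every report carries a checkpoint and retention is left to Ray. As a hedged sketch of Ray's 2.7+ configuration surface (shown with the Tuner API for clarity; it is not implied that `hyperparameter_search` forwards these exact objects), checkpoint retention could be capped like this:

import ray.train
from ray import tune

# Keep only the best checkpoint per trial, ranked by the reported "objective".
checkpoint_config = ray.train.CheckpointConfig(
    num_to_keep=1,
    checkpoint_score_attribute="objective",
    checkpoint_score_order="max",
)
run_config = ray.train.RunConfig(checkpoint_config=checkpoint_config)
tuner = tune.Tuner(trainable, param_space={}, run_config=run_config)  # `trainable` as sketched near the top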
src/transformers/trainer.py (44 changes: 22 additions & 22 deletions)
@@ -28,6 +28,7 @@
 import re
 import shutil
 import sys
+import tempfile
 import time
 import warnings
 from collections.abc import Mapping
@@ -595,7 +596,6 @@ def __init__(
         # returned to 0 every time flos need to be logged
         self.current_flos = 0
         self.hp_search_backend = None
-        self.use_tune_checkpoints = False
         default_label_names = find_labels(self.model.__class__)
         self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
         self.can_return_loss = can_return_loss(self.model.__class__)
@@ -1201,7 +1201,8 @@ def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
     def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]):
         if self.hp_search_backend is None or trial is None:
             return
-        self.objective = self.compute_objective(metrics.copy())
+        metrics = metrics.copy()
+        self.objective = self.compute_objective(metrics)
         if self.hp_search_backend == HPSearchBackend.OPTUNA:
             import optuna
 
@@ -1211,24 +1212,23 @@ def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]):
                     self.callback_handler.on_train_end(self.args, self.state, self.control)
                     raise optuna.TrialPruned()
         elif self.hp_search_backend == HPSearchBackend.RAY:
-            from ray import tune
-
-            if self.control.should_save:
-                self._tune_save_checkpoint()
-            tune.report(objective=self.objective, **metrics)
-
-    def _tune_save_checkpoint(self):
-        from ray import tune
-
-        if not self.use_tune_checkpoints:
-            return
-        with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:
-            output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
-            self.save_model(output_dir, _internal_call=True)
-            if self.args.should_save:
-                self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
-                torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
-                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
+            import ray.train
+
+            with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
+                checkpoint = None
+                if self.control.should_save:
+                    self._tune_save_checkpoint(checkpoint_dir=temp_checkpoint_dir)
+                    checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
+                metrics["objective"] = self.objective
+                ray.train.report(metrics, checkpoint=checkpoint)
+
+    def _tune_save_checkpoint(self, checkpoint_dir: str):
+        output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
+        self.save_model(output_dir, _internal_call=True)
+        if self.args.should_save:
+            self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
+            torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
+            torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
 
     def call_model_init(self, trial=None):
         model_init_argcount = number_of_arguments(self.model_init)
@@ -2004,9 +2004,9 @@ def _get_output_dir(self, trial):
         if self.hp_search_backend == HPSearchBackend.OPTUNA:
             run_id = trial.number
         elif self.hp_search_backend == HPSearchBackend.RAY:
-            from ray import tune
+            import ray.train
 
-            run_id = tune.get_trial_id()
+            run_id = ray.train.get_context().get_trial_id()
         elif self.hp_search_backend == HPSearchBackend.SIGOPT:
             run_id = trial.id
         elif self.hp_search_backend == HPSearchBackend.WANDB:
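
One last detail from `_get_output_dir`: the per-trial run id now comes from the Ray 2.7+ context object instead of `tune.get_trial_id()`. A minimal sketch, only meaningful when executed inside a running Tune trial (the directory layout shown is illustrative, not taken from this diff):

import os

import ray.train

run_id = ray.train.get_context().get_trial_id()
output_dir = os.path.join("hp_search_output", f"run-{run_id}")  # hypothetical layout
os.makedirs(output_dir, exist_ok=True)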
