Added multi-epoch training capability (#73)
Signed-off-by: Daniel Egert <degert@nvidia.com>

Added multi-epoch training capability for RM, PPO, DPO, and SFT
trias702 committed Jan 28, 2024
1 parent 003b655 commit 8399e5a
Showing 12 changed files with 215 additions and 96 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -10,8 +10,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
- PPO: memory optimization to help avoid OOM in the actor when sending training data to the critic.
- PPO: it is now possible to use a custom end string in `sampling_params.end_strings` that is different from `<extra_id_1>`.
- SFT: added support for custom validation metrics based on model generations.
- Added the ability to do multi-epoch (`cfg.max_epochs > 1`) training for reward models, DPO, PPO, and SFT.

### Breaking changes
- The shuffle logic in the data sampler has been changed to support multi-epoch training: the shuffle seed is now offset slightly for each epoch.
As a result, training runs with identical parameters will no longer reproduce earlier results.
If you run CI/regression-style tests, be aware that they may break because of this change.

### Bug Fixes
- Fixed a potential issue when the base model's `model.data.data_prefix` config is a list and is about to be overridden with
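The breaking change above boils down to how the sampler seeds its shuffle: the seed is offset by the epoch index, so each epoch (and any run made after this change) visits the data in a different order. A minimal sketch of the idea, not the actual sampler in `nemo_aligner/data/nlp/samplers.py`:

```python
# Minimal sketch of per-epoch seeded shuffling (illustrative only, not the
# actual NeMo-Aligner sampler implementation).
import torch

def shuffled_indices(total_samples, seed, epoch):
    g = torch.Generator()
    g.manual_seed(seed + epoch)  # the base seed is offset per epoch
    return torch.randperm(total_samples, generator=g).tolist()

print(shuffled_indices(8, seed=1234, epoch=0))
print(shuffled_indices(8, seed=1234, epoch=1))  # a different permutation in epoch 1
```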
2 changes: 1 addition & 1 deletion examples/nlp/gpt/conf/gpt_dpo.yaml
@@ -45,7 +45,7 @@ exp_manager:
mode: min
always_save_nemo: False # saves nemo file during validation, not implemented for model parallel
save_nemo_on_train_end: True # not recommended when training large models on clusters with short time limits
filename: 'megatron_gpt--{val_loss:.3f}-{step}-{consumed_samples}-{epoch}'
filename: 'megatron_gpt--{${.monitor}:.3f}-{step}-{consumed_samples}-{epoch}'
model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}}

pretrained_checkpoint:
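The filename change above (here and in the other example configs below) swaps the hard-coded `val_loss` for `${.monitor}`, an OmegaConf relative interpolation that resolves against the sibling `monitor` key, so checkpoint names follow whatever metric the callback monitors. A small standalone illustration; the values are invented for the example:

```python
# Standalone illustration of the relative interpolation used in the checkpoint
# filename pattern.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "checkpoint_callback_params": {
            "monitor": "val_loss",
            "filename": "megatron_gpt--{${.monitor}:.3f}-{step}-{consumed_samples}-{epoch}",
        }
    }
)

# ${.monitor} resolves to the sibling key; the remaining {…} placeholders stay
# literal for the checkpoint callback to format at save time.
print(cfg.checkpoint_callback_params.filename)
# megatron_gpt--{val_loss:.3f}-{step}-{consumed_samples}-{epoch}
```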
2 changes: 1 addition & 1 deletion examples/nlp/gpt/conf/gpt_ppo_actor.yaml
@@ -9,6 +9,7 @@ trainer:
precision: bf16

ppo:
max_epochs: 1
max_steps: -1 # max PPO steps (-1 to go through the whole train set)
val_check_interval: 10
save_interval: ${.val_check_interval}
@@ -23,7 +24,6 @@ trainer:

# pick up from the model
# *do not change this*
max_epochs: 1 # anything above 1 not supported
model_gbs: ${model.global_batch_size}
model_mbs: ${model.micro_batch_size}

4 changes: 2 additions & 2 deletions examples/nlp/gpt/conf/gpt_sft.yaml
@@ -47,8 +47,8 @@ exp_manager:
monitor: val_loss
save_top_k: 5
mode: min
save_nemo_on_train_end: False
filename: 'megatron_gpt_sft--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}-{consumed_samples}'
save_nemo_on_train_end: False
filename: 'megatron_gpt_sft--{${.monitor}:.3f}-{step}-{consumed_samples}-{epoch}'
model_parallel_size: ${model.tensor_model_parallel_size}
save_best_model: False # need to keep this false otherwise it will create multiple last.ckpt files because restore reset the previous best model

2 changes: 1 addition & 1 deletion examples/nlp/gpt/conf/training_rm.yaml
@@ -47,7 +47,7 @@ exp_manager:
mode: min
always_save_nemo: False # saves nemo file during validation, not implemented for model parallel
save_nemo_on_train_end: True # not recommended when training large models on clusters with short time limits
filename: 'megatron_gpt--{val_loss:.3f}-{step}-{consumed_samples}-{epoch}'
filename: 'megatron_gpt--{${.monitor}:.3f}-{step}-{consumed_samples}-{epoch}'
model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}}

pretrained_checkpoint:
28 changes: 16 additions & 12 deletions nemo_aligner/algorithms/dpo.py
@@ -26,7 +26,7 @@
from nemo.utils import logging
from nemo_aligner.utils.distributed import SyncTimer
from nemo_aligner.utils.train_utils import clip_gradients
from nemo_aligner.utils.trainer_utils import check_progress, compute_limit_batches
from nemo_aligner.utils.trainer_utils import check_progress, compute_limit_batches, compute_num_steps_per_epoch
from nemo_aligner.utils.utils import clear_memory


@@ -95,13 +95,13 @@ def __init__(
self.run_timer = run_timer

self.step = 0
self.epoch = 0
self.consumed_samples = 0

self.ckpt_callback = ckpt_callback

# used to compute the max step
self._train_dataloader_len = len(train_dataloader)
# compute `max_steps`
self.num_steps_per_epoch = compute_num_steps_per_epoch(self.train_dataloader.batch_sampler)

self.limit_val_batches = compute_limit_batches(len(val_dataloader), self.cfg.limit_val_batches)
self.val_check_interval = (
int(self.cfg.val_check_interval * self._train_dataloader_len)
@@ -196,10 +196,12 @@ def fit(self):
self.run_timer.start_time()

for _ in epoch_iter:
loop_iter = range(self.step, self.max_steps)
num_steps_in_epoch = min(
self.max_steps - self.step, self.num_steps_per_epoch - self.step % self.num_steps_per_epoch
)
loop_iter = range(num_steps_in_epoch)

# TODO(geshen): to change for when we support > 1 epoch
if len(loop_iter) <= 0:
if not loop_iter:
return # training ended

global_pbar = tqdm(
@@ -223,6 +225,7 @@ def fit(self):
self.consumed_samples += self.model.cfg.global_batch_size
metrics["consumed_samples"] = self.consumed_samples
metrics["step_time"] = train_step_time
metrics["epoch"] = self.epoch + 1
self.logger.log_metrics(
metrics, step=self.step, prefix="train/",
)
@@ -261,8 +264,6 @@ def fit(self):

metrics.clear()

self.epoch += 1

self.logger.finalize()

def save(self, extra_candidates=None, is_train_end=False):
@@ -278,7 +279,7 @@ def save(self, extra_candidates=None, is_train_end=False):
self.ckpt_callback.custom_save(monitor_candidates=monitor_candidates, is_train_end=is_train_end)

def set_max_steps(self):
self.max_steps = self._train_dataloader_len + self.step
self.max_steps = self.num_steps_per_epoch * self.cfg.max_epochs

if (max_steps := self.cfg.get("max_steps", -1)) >= 0:
self.max_steps = min(self.max_steps, max_steps)
@@ -293,9 +294,8 @@ def state_dict(self):
def load_state_dict(self, state_dict):
self.step = state_dict["step"]
self.consumed_samples = state_dict["consumed_samples"]
self.epoch = state_dict["epoch"]

loaded_values = [self.step, self.consumed_samples, self.epoch]
loaded_values = [self.step, self.consumed_samples]

# make sure everyone loaded the same checkpoint as rank 0
to_broadcast = torch.tensor(loaded_values, dtype=torch.float32, device=torch.cuda.current_device())
@@ -329,3 +329,7 @@ def augment_dataloader(self, dataloader):
yield batch
buffer.clear()
del logprobs

@property
def epoch(self):
return self.step // self.num_steps_per_epoch
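With this change the DPO trainer derives `epoch` from `step` instead of tracking it separately, and each pass through `fit()` only runs the steps remaining in the current epoch, capped by `max_steps`. A small numeric sketch of that bookkeeping, with `num_steps_per_epoch` assumed rather than taken from a real batch sampler:

```python
# Illustrative numbers only; in the trainer, num_steps_per_epoch comes from
# compute_num_steps_per_epoch(train_dataloader.batch_sampler).
def steps_left_in_epoch(step, max_steps, num_steps_per_epoch):
    # remaining steps before either the epoch boundary or the global step budget
    return min(max_steps - step, num_steps_per_epoch - step % num_steps_per_epoch)

num_steps_per_epoch = 100
max_steps = 250  # e.g. num_steps_per_epoch * max_epochs, capped by cfg.max_steps

for step in (0, 100, 230):
    epoch = step // num_steps_per_epoch  # the new `epoch` property
    print(epoch, steps_left_in_epoch(step, max_steps, num_steps_per_epoch))
# 0 100
# 1 100
# 2 20
```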
56 changes: 29 additions & 27 deletions nemo_aligner/algorithms/ppo.py
@@ -23,6 +23,7 @@

from nemo.collections.nlp.modules.common.megatron.utils import get_iterator_k_split
from nemo.utils import logging
from nemo_aligner.data.nlp.samplers import MegatronPretrainingRandomSampler
from nemo_aligner.utils.distributed import (
SyncTimer,
masked_global_mean_var,
@@ -37,7 +38,7 @@
)
from nemo_aligner.utils.server_utils import FutureResult
from nemo_aligner.utils.train_utils import clip_gradients
from nemo_aligner.utils.trainer_utils import check_progress
from nemo_aligner.utils.trainer_utils import check_progress, compute_num_steps_per_epoch
from nemo_aligner.utils.utils import clear_memory, cpu_dict, masked_mean


@@ -79,14 +80,13 @@ def __init__(
self.run_timer = run_timer

self.consumed_samples = 0
self.epoch = 0
# the step here is PPO step
self.step = 0
# keep track of how many times we optimized the actor
self.ppo_optimization_step = 0

# used to compute the max step
self._train_dataloader_len = len(train_dataloader)
# compute `max_steps`
self.num_steps_per_epoch = compute_num_steps_per_epoch(self.train_dataloader.batch_sampler)
self.set_max_steps()

self.compute_init_policy_kl = self.cfg.initial_policy_kl_penalty > 0
@@ -348,21 +348,29 @@ def run_training(self, dataloader_iter):
return loss_mean, metrics

def fit(self):
if self.cfg.max_epochs is not None and self.cfg.max_epochs > 1:
# because we need to figure out a nice way to reset the shuffling on our dataset
# otherwise epoch > 1 will loop over the dataset in the same order
raise ValueError("epoch > 1 is not supported")
if (not isinstance(self.train_dataloader.batch_sampler, MegatronPretrainingRandomSampler)) and (
self.cfg.max_epochs is not None and self.cfg.max_epochs > 1
):
# If MegatronPretrainingBatchSampler is passed as the batch_sampler to the train dataloader (in builders.py),
# every epoch repeats the samples in the same order as the previous epoch because there is no reshuffling.
# Use MegatronPretrainingRandomSampler instead, which reshuffles the data at every epoch.
raise ValueError(
"max_epochs > 1 is not supported unless using `MegatronPretrainingRandomSampler` as the batch_sampler for your train dataloader"
)

epoch_iter = range(self.epoch, self.cfg.max_epochs)
if len(epoch_iter) <= 0:
# epoch done
return

for _ in epoch_iter:
loop_iter = range(self.step, self.max_steps)
num_steps_in_epoch = min(
self.max_steps - self.step, self.num_steps_per_epoch - self.step % self.num_steps_per_epoch
)
loop_iter = range(num_steps_in_epoch)

# TODO(geshen): to change for when we support > 1 epoch
if len(loop_iter) <= 0:
if not loop_iter:
return # training ended

dataloader_iter = iter(self.train_dataloader)
@@ -396,6 +404,7 @@ def fit(self):
table_metrics["response"],
table_metrics["reward"],
]
metrics["epoch"] = self.epoch + 1
self.logger.log_metrics(
metrics, step=self.step, prefix="train_rollouts/",
)
@@ -459,8 +468,6 @@ def fit(self):
logging.info(f"Time limit given by run_timer={self.run_timer} reached. Stopping run")
return

self.epoch += 1

self.logger.finalize()

def state_dict(self):
@@ -474,10 +481,9 @@ def state_dict(self):
def load_state_dict(self, state_dict):
self.step = state_dict["step"]
self.consumed_samples = state_dict["consumed_samples"]
self.epoch = state_dict["epoch"]
self.ppo_optimization_step = state_dict["ppo_optimization_step"]

loaded_values = [self.step, self.consumed_samples, self.epoch, self.ppo_optimization_step]
loaded_values = [self.step, self.consumed_samples, self.ppo_optimization_step]

# make sure everyone loaded the same checkpoint as rank 0
to_broadcast = torch.tensor(loaded_values, dtype=torch.float32, device=torch.cuda.current_device())
@@ -508,15 +514,11 @@ def save(self, extra_candidates=None, is_train_end=False):
self.model.finish_training()

def set_max_steps(self):
max_steps = self.cfg.get("max_steps", -1)

if max_steps == -1:
# the dataloader already knows how much longer
# because consumed samples is resumed
max_steps = self._train_dataloader_len
else:
# user specified the max step, figure out how much longer
# we need to run for
max_steps = max_steps - self.step

self.max_steps = min(max_steps, self._train_dataloader_len) + self.step
self.max_steps = self.num_steps_per_epoch * self.cfg.max_epochs

if (max_steps := self.cfg.get("max_steps", -1)) >= 0:
self.max_steps = min(self.max_steps, max_steps)

@property
def epoch(self):
return self.step // self.num_steps_per_epoch
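`set_max_steps` in the PPO trainer now budgets whole epochs (`num_steps_per_epoch * max_epochs`) and only trims that budget when `cfg.max_steps` is set to a non-negative value. A hedged sketch of the resulting arithmetic, with plain arguments standing in for the config:

```python
# Sketch of the new max-step budget; plain arguments stand in for the DictConfig.
def resolve_max_steps(num_steps_per_epoch, max_epochs, cfg_max_steps=-1):
    max_steps = num_steps_per_epoch * max_epochs
    if cfg_max_steps >= 0:
        # a user-specified cap wins when it is smaller than the epoch budget
        max_steps = min(max_steps, cfg_max_steps)
    return max_steps

print(resolve_max_steps(100, 3))       # 300: three full epochs
print(resolve_max_steps(100, 3, 250))  # 250: capped by cfg.max_steps
```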
56 changes: 30 additions & 26 deletions nemo_aligner/algorithms/supervised.py
@@ -19,11 +19,14 @@
from omegaconf.dictconfig import DictConfig
from tqdm import tqdm

from nemo.collections.nlp.data.language_modeling.megatron.megatron_batch_samplers import (
MegatronPretrainingRandomBatchSampler,
)
from nemo.utils import logging
from nemo_aligner.metrics import InferenceMetricsHandler
from nemo_aligner.utils.distributed import SyncTimer
from nemo_aligner.utils.train_utils import clip_gradients
from nemo_aligner.utils.trainer_utils import check_progress, compute_limit_batches
from nemo_aligner.utils.trainer_utils import check_progress, compute_limit_batches, compute_num_steps_per_epoch


class SupervisedTrainer:
@@ -57,13 +60,12 @@ def __init__(
self.run_timer = run_timer

self.step = 0
self.epoch = 0
self.consumed_samples = 0

self.ckpt_callback = ckpt_callback

# used to compute the max step
self._train_dataloader_len = len(train_dataloader)
# compute `max_steps`
self.num_steps_per_epoch = compute_num_steps_per_epoch(self.train_dataloader.batch_sampler)

self.limit_val_batches = compute_limit_batches(len(val_dataloader), self.cfg.limit_val_batches)
self.set_max_steps()
@@ -148,10 +150,16 @@ def run_generation(self, batch):
return self.model.infer({"text": batch["contexts"], "length": batch["context_lengths"]})

def fit(self):
if self.cfg.max_epochs is not None and self.cfg.max_epochs > 1:
# because we need to figure out a nice way to reset the shuffling on our dataset
# otherwise epoch > 1 will loop over the dataset in the same order
raise ValueError("epoch > 1 is not supported")
if (not isinstance(self.train_dataloader.batch_sampler, MegatronPretrainingRandomBatchSampler)) and (
self.cfg.max_epochs is not None and self.cfg.max_epochs > 1
):
# If MegatronPretrainingBatchSampler is passed as the batch_sampler to the train dataloader (in builders.py),
# every epoch repeats the samples in the same order as the previous epoch because there is no reshuffling.
# Use MegatronPretrainingRandomBatchSampler instead, which reshuffles the data at every epoch.
raise ValueError(
"max_epochs > 1 is not supported unless using `MegatronPretrainingRandomBatchSampler` as the batch_sampler for your train dataloader"
)

epoch_iter = range(self.epoch, self.cfg.max_epochs)
if len(epoch_iter) <= 0:
@@ -161,10 +169,12 @@ def fit(self):
self.run_timer.start_time()

for _ in epoch_iter:
loop_iter = range(self.step, self.max_steps)
num_steps_in_epoch = min(
self.max_steps - self.step, self.num_steps_per_epoch - self.step % self.num_steps_per_epoch
)
loop_iter = range(num_steps_in_epoch)

# TODO(geshen): to change for when we support > 1 epoch
if len(loop_iter) <= 0:
if not loop_iter:
return # training ended

global_pbar = tqdm(
@@ -182,6 +192,7 @@ def fit(self):
self.consumed_samples += self.model.cfg.global_batch_size
metrics["consumed_samples"] = self.consumed_samples
metrics["step_time"] = train_step_time
metrics["epoch"] = self.epoch + 1
self.logger.log_metrics(
metrics, step=self.step, prefix="train/",
)
@@ -220,8 +231,6 @@ def fit(self):

metrics.clear()

self.epoch += 1

self.logger.finalize()

def save(self, extra_candidates=None, is_train_end=False):
@@ -237,18 +246,10 @@ def save(self, extra_candidates=None, is_train_end=False):
self.ckpt_callback.custom_save(monitor_candidates=monitor_candidates, is_train_end=is_train_end)

def set_max_steps(self):
max_steps = self.cfg.get("max_steps", -1)
self.max_steps = self.num_steps_per_epoch * self.cfg.max_epochs

if max_steps == -1:
# the dataloader already knows how much longer
# because consumed samples is resumed
max_steps = self._train_dataloader_len
else:
# user specified the max step, figure out how much longer
# we need to run for
max_steps = max_steps - self.step

self.max_steps = min(max_steps, self._train_dataloader_len) + self.step
if (max_steps := self.cfg.get("max_steps", -1)) >= 0:
self.max_steps = min(self.max_steps, max_steps)

def state_dict(self):
return {
@@ -260,9 +261,8 @@ def state_dict(self):
def load_state_dict(self, state_dict):
self.step = state_dict["step"]
self.consumed_samples = state_dict["consumed_samples"]
self.epoch = state_dict["epoch"]

loaded_values = [self.step, self.consumed_samples, self.epoch]
loaded_values = [self.step, self.consumed_samples]

# make sure everyone loaded the same checkpoint as rank 0
to_broadcast = torch.tensor(loaded_values, dtype=torch.float32, device=torch.cuda.current_device())
@@ -271,3 +271,7 @@ def load_state_dict(self, state_dict):
assert loaded_values == to_broadcast.tolist()
# restore max steps we need to run for
self.set_max_steps()

@property
def epoch(self):
return self.step // self.num_steps_per_epoch
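Across all three trainers, `epoch` is no longer stored in or restored from the state dict; re-loading `step` is enough because the `epoch` property recomputes it. A toy resume path showing the effect (checkpoint I/O and the rank-0 broadcast are omitted, and the numbers are made up):

```python
# Toy example; real checkpoints also broadcast the loaded values from rank 0 so
# every worker resumes from the same state.
class TinyTrainer:
    def __init__(self, num_steps_per_epoch):
        self.num_steps_per_epoch = num_steps_per_epoch
        self.step = 0
        self.consumed_samples = 0

    def load_state_dict(self, state_dict):
        self.step = state_dict["step"]
        self.consumed_samples = state_dict["consumed_samples"]

    @property
    def epoch(self):
        return self.step // self.num_steps_per_epoch

trainer = TinyTrainer(num_steps_per_epoch=100)
trainer.load_state_dict({"step": 150, "consumed_samples": 150 * 8})  # made-up values
print(trainer.epoch)  # 1 -> resumes partway through the second epoch
```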