diff --git a/CHANGELOG.md b/CHANGELOG.md
index 809b2f941..81888d474 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,9 +6,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
 
 ## [Next Version]
 
 ### New features and optimizations
-- Added public-facing official Dockerfile for NeMo-Aligner
-- Memory optimization in PPO that helps avoid OOM in the actor when sending training data to the critic
-- SFT: added support for custom validation metrics based on model generations
+- Added public-facing official Dockerfile for NeMo-Aligner.
+- PPO: memory optimization to help avoid OOM in the actor when sending training data to the critic.
+- PPO: it is now possible to use a custom end string in `sampling_params.end_strings` that is different from `"<extra_id_1>"`.
+- SFT: added support for custom validation metrics based on model generations.
 
 ### Breaking changes
diff --git a/examples/nlp/gpt/conf/gpt_ppo_actor.yaml b/examples/nlp/gpt/conf/gpt_ppo_actor.yaml
index 4b6439d6f..68edc27ce 100644
--- a/examples/nlp/gpt/conf/gpt_ppo_actor.yaml
+++ b/examples/nlp/gpt/conf/gpt_ppo_actor.yaml
@@ -10,7 +10,7 @@ trainer:
 
   ppo:
     max_steps: -1 # max PPO steps (-1 to go through the whole train set)
-    val_check_interval: 10 
+    val_check_interval: 10
     save_interval: ${.val_check_interval}
     gradient_clip_val: 1.0
 
@@ -49,7 +49,7 @@ remote_critic_rm:
   # must match the same flag in the critic config
   combine_rm_and_critic_server: True
 
-  # reward model server, specify if 
+  # reward model server, specify if
   # combine_rm_and_critic server is False
   reward_model:
     name: reward_model
@@ -57,7 +57,7 @@ remote_critic_rm:
     port: 5555
 
   critic:
-    name: 
+    name:
       train: critic_train
       infer: critic_infer
       save: critic_save
@@ -146,15 +146,15 @@ model:
 
   # miscellaneous
   seed: 1234
-  
+
   optim:
     name: distributed_fused_adam
     bucket_cap_mb: 200
     overlap_grad_sync: False
     contiguous_grad_buffer: True
     lr: 9e-7
-    weight_decay: 0.1 
-    betas: 
+    weight_decay: 0.1
+    betas:
     - 0.9
     - 0.98
     sched:
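For reference, a minimal sketch (not part of this patch) of how the new custom end-string support could be exercised from Python. The `model.ppo.sampling_params.end_strings` path is inferred from how the actor reads `self.cfg.ppo.sampling_params`, and the stop string shown is purely illustrative:

```python
from omegaconf import OmegaConf

# Hypothetical override: with the old safety check removed (see the actor diff below),
# any stop string may be supplied, not only "<|endoftext|>" / "<extra_id_1>".
overrides = OmegaConf.create(
    {"model": {"ppo": {"sampling_params": {"end_strings": ["<|endoftext|>", "\n\nUser:"]}}}}
)
base = OmegaConf.load("examples/nlp/gpt/conf/gpt_ppo_actor.yaml")
cfg = OmegaConf.merge(base, overrides)
```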
diff --git a/nemo_aligner/models/nlp/gpt/megatron_gpt_ppo_actor.py b/nemo_aligner/models/nlp/gpt/megatron_gpt_ppo_actor.py
index 8d76eb18c..1661146a5 100644
--- a/nemo_aligner/models/nlp/gpt/megatron_gpt_ppo_actor.py
+++ b/nemo_aligner/models/nlp/gpt/megatron_gpt_ppo_actor.py
@@ -32,10 +32,11 @@
 from nemo.collections.nlp.parts.utils_funcs import get_last_rank
 from nemo_aligner.models.alignable_interface import AlignableGenerativeInterface
 from nemo_aligner.utils.distributed import (
-    broadcast_2d_tensor,
+    broadcast_2d_tensor_within_pp,
     calculate_distributed_entropy,
     from_parallel_logits_to_logprobs,
 )
+from nemo_aligner.utils.text_generation_utils import TrackLengthGPTModelTextGenerationStrategy
 from nemo_aligner.utils.train_utils import (
     grad_reductions,
     prepare_for_training_step,
@@ -43,13 +44,7 @@
     set_sync_funcs,
     set_train,
 )
-from nemo_aligner.utils.utils import (
-    calculate_dialogue_response_lengths,
-    configure_batch_sizes,
-    cpu_weight_swap,
-    masked_mean,
-    offload_distributed_adam,
-)
+from nemo_aligner.utils.utils import configure_batch_sizes, cpu_weight_swap, masked_mean, offload_distributed_adam
 
 
 class MegatronGPTActorModel(MegatronGPTModel, AlignableGenerativeInterface):
@@ -65,15 +60,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
         # sampling parameters for generation
         self._sampling_params = OmegaConf.to_container(self.cfg.ppo.sampling_params, resolve=True)
 
-        # Safety check until https://github.com/NVIDIA/NeMo-Aligner/pull/19 is merged.
-        valid_end_strings = ["<|endoftext|>", "<extra_id_1>"]
-        for end_string in self._sampling_params["end_strings"]:
-            if end_string not in valid_end_strings:
-                raise NotImplementedError(
-                    "Currently only '<|endoftext|>' and '<extra_id_1>' are allowed in `sampling_params.end_strings`, "
-                    f"but found '{end_string}'"
-                )
-
         self.to_offload_adam_states = self.cfg.ppo.offload_adam_states
         self.entropy_bonus = self.cfg.ppo.entropy_bonus
         self.ratio_eps = self.cfg.ppo.ratio_eps
@@ -256,13 +242,9 @@ def get_inference_log_probs(self, response_tokens, forward_micro_batch_size):
         )
 
         logprobs = torch.cat(logprobs_list) if len(logprobs_list) > 0 else None
-        if parallel_state.get_pipeline_model_parallel_world_size() > 1:
-            # broadcast it from last PP stage to everything else
-            logprobs = broadcast_2d_tensor(
-                logprobs,
-                parallel_state.get_pipeline_model_parallel_last_rank(),
-                parallel_state.get_pipeline_model_parallel_group(),
-            )
+
+        # Broadcast it from last PP stage to everything else.
+        logprobs = broadcast_2d_tensor_within_pp(logprobs)
 
         return logprobs
 
@@ -280,19 +262,32 @@ def infer(self, inference_batch):
         prompt_lengths = inference_batch["length"].cuda(non_blocking=True)
         inputs = (prompt_tokens, prompt_lengths)
 
+        strategy = TrackLengthGPTModelTextGenerationStrategy(
+            model=self, context_lengths=prompt_lengths, max_length=self._length_params["max_length"]
+        )
         actor_output = self.generate(
-            inputs=inputs, length_params=self._length_params, sampling_params=self._sampling_params
+            inputs=inputs, length_params=self._length_params, sampling_params=self._sampling_params, strategy=strategy
         )
 
+        response_lengths = strategy.get_lengths()
+        max_response_length = response_lengths.max().item()
+
         response_tokens = torch.cuda.LongTensor(actor_output["token_ids"])
-        response_lengths = calculate_dialogue_response_lengths(
-            tokens=response_tokens,
-            prompt_lengths=prompt_lengths,
-            tokenizer=self.tokenizer,
-            end_strings=self._sampling_params["end_strings"],
-            max_generation_length=self._length_params["max_length"],
-            max_sequence_length=self.cfg.encoder_seq_length,
-        )
+
+        # Sanity check to validate response length.
+        if max_response_length != response_tokens.size(1):
+            # This may actually happen because NeMo does not always stop generation after `max_length` in batch mode
+            # => `response_tokens` may contain up to `max_length + max_context_length` tokens.
+            # TODO once NeMo fixes this issue we should be able to always raise an exception when the check above fails,
+            # and remove the `if` below.
+            if (
+                max_response_length >= response_tokens.size(1)
+                or response_tokens.size(1) != prompt_lengths.max().item() + self._length_params["max_length"]
+            ):
+                raise AssertionError(
+                    f"max response length ({max_response_length}) does not match the size of "
+                    f"`response_tokens` ({response_tokens.size(1)})"
+                )
 
         # TODO(geshen): get nemo generate to return the unaltered log probs
         log_probs = self.get_inference_log_probs(
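The tolerated mismatch in the new sanity check is easy to miss, so here is a toy, framework-free sketch (plain tensors, not the PR's actual code path) of when it raises and when the known NeMo overshoot case is allowed:

```python
import torch


def check_response_tokens(response_lengths, response_tokens, prompt_lengths, max_length):
    # Mirrors the check added to `infer()`: a width mismatch is tolerated only when
    # generation overshot to exactly `max(prompt_lengths) + max_length` columns.
    max_response_length = response_lengths.max().item()
    if max_response_length != response_tokens.size(1):
        if (
            max_response_length >= response_tokens.size(1)
            or response_tokens.size(1) != prompt_lengths.max().item() + max_length
        ):
            raise AssertionError("response length / width mismatch")


prompt_lengths = torch.tensor([3, 5])
response_tokens = torch.zeros((2, 9), dtype=torch.long)  # 9 == 5 (longest prompt) + 4 (max_length)
check_response_tokens(torch.tensor([7, 9]), response_tokens, prompt_lengths, max_length=4)  # exact match: passes
check_response_tokens(torch.tensor([7, 8]), response_tokens, prompt_lengths, max_length=4)  # tolerated overshoot case
```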
diff --git a/nemo_aligner/models/nlp/gpt/megatron_gpt_reward_model.py b/nemo_aligner/models/nlp/gpt/megatron_gpt_reward_model.py
index 056769086..8b0ed9d76 100644
--- a/nemo_aligner/models/nlp/gpt/megatron_gpt_reward_model.py
+++ b/nemo_aligner/models/nlp/gpt/megatron_gpt_reward_model.py
@@ -35,7 +35,7 @@
 from nemo.utils import AppState, logging
 from nemo_aligner.models.alignable_interface import Inferrable, SupervisedInterface
 from nemo_aligner.models.nlp.gpt.gpt_reward_model import GPTRewardModel
-from nemo_aligner.utils.distributed import broadcast_2d_tensor, gather_tensor
+from nemo_aligner.utils.distributed import broadcast_2d_tensor, broadcast_2d_tensor_within_pp, gather_tensor
 from nemo_aligner.utils.text_generation_utils import tokenize_batch
 from nemo_aligner.utils.train_utils import (
     finish_validation_step,
@@ -415,12 +415,7 @@ def infer(
         if self.enable_standardization:
             rewards = (rewards - self.rew_mean) / self.rew_std
 
-        if parallel_state.get_pipeline_model_parallel_world_size() > 1:
-            rewards = broadcast_2d_tensor(
-                rewards,
-                parallel_state.get_pipeline_model_parallel_last_rank(),
-                parallel_state.get_pipeline_model_parallel_group(),
-            )
+        rewards = broadcast_2d_tensor_within_pp(rewards)
 
         rewards_list = gather_tensor(
             rewards, dst=parallel_state.get_data_parallel_src_rank(), group=parallel_state.get_data_parallel_group()
diff --git a/nemo_aligner/utils/distributed.py b/nemo_aligner/utils/distributed.py
index e5f59df1c..78c860108 100644
--- a/nemo_aligner/utils/distributed.py
+++ b/nemo_aligner/utils/distributed.py
@@ -69,13 +69,25 @@ def broadcast_2d_tensor(tensor, src, group, dtype=torch.float32):
     return tensor
 
 
-def broadcast_2d_tensor_within_mp(tensor):
+def broadcast_2d_tensor_within_mp(tensor, dtype=torch.float32):
     """helper function to broadcast within the model parallel group
     """
     group = parallel_state.get_model_parallel_group()
 
     if torch.distributed.get_world_size(group) > 1:
-        return broadcast_2d_tensor(tensor, get_model_parallel_src_rank(), group)
+        return broadcast_2d_tensor(tensor, get_model_parallel_src_rank(), group, dtype=dtype)
+
+    return tensor
+
+
+def broadcast_2d_tensor_within_pp(tensor, dtype=torch.float32):
+    if parallel_state.get_pipeline_model_parallel_world_size() > 1:
+        return broadcast_2d_tensor(
+            tensor,
+            parallel_state.get_pipeline_model_parallel_last_rank(),
+            parallel_state.get_pipeline_model_parallel_group(),
+            dtype=dtype,
+        )
 
     return tensor
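A hedged usage sketch of the new helper. It assumes torch.distributed and Megatron parallel state are already initialized (e.g. inside a NeMo-Aligner job); `per_token_scores` and `share_across_pp` are hypothetical names, not part of the patch:

```python
from typing import Optional

import torch
from megatron.core import parallel_state

from nemo_aligner.utils.distributed import broadcast_2d_tensor_within_pp


def share_across_pp(per_token_scores: Optional[torch.Tensor]) -> torch.Tensor:
    # Only the last pipeline stage materializes the 2D result (as with `logprobs` and
    # `rewards` above); every other PP rank passes None and receives the broadcast copy.
    scores = per_token_scores if parallel_state.is_pipeline_last_stage() else None
    return broadcast_2d_tensor_within_pp(scores, dtype=torch.float32)
```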
diff --git a/nemo_aligner/utils/text_generation_utils.py b/nemo_aligner/utils/text_generation_utils.py
index 3fd5a433f..ae95f3e34 100644
--- a/nemo_aligner/utils/text_generation_utils.py
+++ b/nemo_aligner/utils/text_generation_utils.py
@@ -14,10 +14,65 @@
 
 """Utilities for generating text."""
 
+from typing import Any, List
+
 import torch
+
+from megatron.core import parallel_state
+from nemo.collections.nlp.modules.common.text_generation_strategy import GPTModelTextGenerationStrategy
 from nemo.utils import logging
 
+from nemo_aligner.utils.distributed import broadcast_2d_tensor_within_pp
+
+
+class TrackLengthGPTModelTextGenerationStrategy(GPTModelTextGenerationStrategy):
+    """
+    Text generation strategy that tracks the length of the generated text.
+
+    TODO This is a temporary workaround until NeMo's `generate()` function returns this information.
+    """
+
+    def __init__(self, model: Any, context_lengths: torch.Tensor, max_length: int):
+        super().__init__(model)
+        self._context_lengths = context_lengths
+        self._max_length = max_length
+        self._end_idx = torch.full_like(context_lengths, fill_value=-1)
+
+    def end_of_generation_condition(
+        self, tokens: torch.Tensor, prev: torch.Tensor, eod_id: int, end_strings: List[str]
+    ) -> torch.Tensor:
+        is_end = super().end_of_generation_condition(tokens=tokens, prev=prev, eod_id=eod_id, end_strings=end_strings)
+        assert len(is_end) == len(tokens)
+        if len(tokens) != len(self._context_lengths):
+            raise RuntimeError(
+                "Batch size mismatch: the `context_lengths` tensor provided in the constructor has batch size "
+                f"{len(self._context_lengths)}, while the generated tokens have batch size {len(tokens)}"
+            )
+        context_length = tokens.size(1) - 1  # the input tokens come from `tokens[:, : context_length + 1]`
+        started = self._context_lengths <= context_length
+        # The generation ends right now when three conditions hold:
+        # - it has started
+        # - the end generation is triggered now
+        # - it did *not* end before
+        self._end_idx = torch.where(started & is_end & (self._end_idx < 0), context_length, self._end_idx)
+        return is_end
+
+    def get_lengths(self) -> torch.Tensor:
+        """
+        Return the total lengths of the generated sequences, in # of tokens.
+
+        The total length of a generated sequence counts both:
+            * the context tokens (i.e., the input prompt)
+            * the token(s) that ended generation, if any (e.g. the `EOS` token or the token(s) corresponding to
+              an element of `sampling_params.end_strings`)
+        """
+        lengths = None
+        if parallel_state.is_pipeline_last_stage():  # only the last stage actually has access to lengths
+            lengths = torch.where(self._end_idx >= 0, self._end_idx + 1, self._context_lengths + self._max_length)
+            lengths = lengths.to(torch.int64).view((-1, 1))
+        lengths = broadcast_2d_tensor_within_pp(lengths, dtype=torch.int64)
+        return lengths.flatten()
+
+
 def pad_batch(batch, pad_id):
     """batch each element of the batch to be the size of the longest sequence
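A toy walk-through (plain tensors, no model or generation loop) of the bookkeeping the new strategy relies on: `end_idx` records the first step at which each sequence triggered end-of-generation, and the length computation falls back to `context_length + max_length` for sequences that never ended:

```python
import torch

context_lengths = torch.tensor([3, 5, 4])
max_length = 6
end_idx = torch.full_like(context_lengths, fill_value=-1)

# Pretend per-step `is_end` flags as the base strategy would report them at steps 6 and 8.
for step, is_end in [(6, torch.tensor([True, False, False])), (8, torch.tensor([True, True, False]))]:
    started = context_lengths <= step
    end_idx = torch.where(started & is_end & (end_idx < 0), step, end_idx)

lengths = torch.where(end_idx >= 0, end_idx + 1, context_lengths + max_length)
print(lengths.tolist())  # [7, 9, 10] (the third sequence never ended, so it gets 4 + 6)
```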
diff --git a/nemo_aligner/utils/utils.py b/nemo_aligner/utils/utils.py
index 82206e2a4..e4f43a671 100644
--- a/nemo_aligner/utils/utils.py
+++ b/nemo_aligner/utils/utils.py
@@ -184,44 +184,6 @@ def calculate_response_lengths(tokens, eos_id):
     return (tokens != eos_id).sum(-1)
 
 
-def calculate_dialogue_response_lengths(
-    tokens, prompt_lengths, tokenizer, end_strings, max_generation_length, max_sequence_length
-):
-    # for EOS
-    eos_length = calculate_response_lengths(tokens, tokenizer.eos_id)
-
-    if "<extra_id_1>" in end_strings:
-        # for the extra_id_1
-        extra_id_1_idx = tokenizer.text_to_ids("<extra_id_1>")[-1]
-        mask = tokens == extra_id_1_idx
-
-        # take the last extra id token index (assumes we are not padding with extra_id_1)
-        length_with_extra_id_1 = torch.argmax(
-            mask * torch.arange(tokens.size(-1), device=torch.cuda.current_device()), dim=-1
-        )
-
-        # if it terminated on the extra token id, then it must have been generated by the model, otherwise it couldn't have
-        length_with_extra_id_1 = torch.where(
-            length_with_extra_id_1 >= prompt_lengths, length_with_extra_id_1, torch.iinfo(torch.int32).max
-        )
-
-        # either terminated using eos id or extra id 1
-        lengths = torch.minimum(eos_length, length_with_extra_id_1)
-    else:
-        lengths = eos_length
-
-    # we also want the model to learn EOS or extra id 1
-    lengths = lengths + 1
-    # Ensure we never go over `length_params.max_length`. Note that this means the response may not necessarily
-    # end with EOS / extra_id_1 (we should not enforce it as PPO training requires the real generated token).
-    max_lengths = prompt_lengths + max_generation_length
-    lengths = torch.minimum(lengths, max_lengths)
-
-    # Prompts' max size and `max_length` should be such that we never exceed the encoder input size.
-    assert (lengths <= max_sequence_length).all()
-    return lengths
-
-
 def configure_batch_sizes(mbs, gbs, dp=1):
     app_state = AppState()
     _reconfigure_microbatch_calculator(