From b29a806be59191685f6499df94596f06f7ae4243 Mon Sep 17 00:00:00 2001 From: Alexander Bukharin Date: Fri, 20 Sep 2024 20:09:11 -0400 Subject: [PATCH] Remove logprob computation functions --- docs/user-guide/rs.rst | 6 +- examples/nlp/gpt/train_gpt_rs_actor.py | 2 +- nemo_aligner/algorithms/rs.py | 6 +- .../models/nlp/gpt/megatron_gpt_rs_actor.py | 76 ------------------- 4 files changed, 7 insertions(+), 83 deletions(-) diff --git a/docs/user-guide/rs.rst b/docs/user-guide/rs.rst index 7f790d3d..9fb61fc4 100644 --- a/docs/user-guide/rs.rst +++ b/docs/user-guide/rs.rst @@ -14,7 +14,7 @@ Rejection Sampling Training After you have fine-tuned a GPT model using Supervised Fine-Tuning (SFT), and trained a reward model as explained in the preceding section, you can start aligning the policy using rejection sampling. -During rejection sampling training, we have two models interacting with each other, which Aligner runs in separate jobs:: +During rejection sampling training, we have two models interacting with each other, which Aligner runs in separate jobs: #. The Policy Network: This is the model we are training, and it should start from an SFT model. #. The Reward Model (RM): This model takes a prompt concatenated with a response as input, and outputs a single scalar value: the reward, which the rejection sampling algorithm will try to maximize. @@ -93,7 +93,7 @@ The RS Actor training job contains the master controller that makes the HTTP cal remote_critic_rm.reward_model.ip=${host_critic} \ remote_critic_rm.reward_model.port=${CRITIC_PORT} \ model.rs.num_rollout_per_prompt=8 \ - model.rs.num_select=1 + model.rs.top_n_rollouts=1 The above launches the initial and actor server on 1 node with 8 GPUs @@ -210,7 +210,7 @@ You can use slurm to launch the 2 jobs and get them to coordinate together in a remote_critic_rm.reward_model.ip=${host_critic} \ remote_critic_rm.reward_model.port=${CRITIC_PORT} \ model.rs.num_rollout_per_prompt=8 \ - model.rs.num_select=1 + model.rs.top_n_rollouts=1 EOF srun --het-group=1 -o $PPO_OUTFILE -e $PPO_ERRFILE --container-image=${CONTAINER} $MOUNTS bash -c "${cmd_rs}" & diff --git a/examples/nlp/gpt/train_gpt_rs_actor.py b/examples/nlp/gpt/train_gpt_rs_actor.py index 6efa2a58..f01e7073 100644 --- a/examples/nlp/gpt/train_gpt_rs_actor.py +++ b/examples/nlp/gpt/train_gpt_rs_actor.py @@ -160,7 +160,7 @@ def main(cfg) -> None: ckpt_callback=ckpt_callback, run_timer=timer, num_rollout_per_prompt=cfg.model.rs.num_rollouts_per_prompt, - num_select=cfg.model.rs.top_n_rollouts, + top_n_rollouts=cfg.model.rs.top_n_rollouts, ) if custom_trainer_state_dict is not None: diff --git a/nemo_aligner/algorithms/rs.py b/nemo_aligner/algorithms/rs.py index 3113f902..572bbe85 100644 --- a/nemo_aligner/algorithms/rs.py +++ b/nemo_aligner/algorithms/rs.py @@ -53,7 +53,7 @@ def __init__( ckpt_callback, run_timer, num_rollout_per_prompt, - num_select, + top_n_rollouts, rm, ): self.cfg = cfg @@ -65,7 +65,7 @@ def __init__( self.logger = logger self.ckpt_callback = ckpt_callback self.num_rollout_per_prompt = num_rollout_per_prompt - self.num_select = num_select + self.top_n_rollouts = top_n_rollouts self.rm = rm # this timer checks if we should stop training @@ -186,7 +186,7 @@ def _run_inference(self, dataloader_iter, num_microbatches, is_validation): current_batch["rewards"] = torch.concatenate([current_batch["rewards"], rewards], dim=0) else: current_batch["rewards"] = rewards - rollout_batch = select_topk(current_batch, self.num_select) + rollout_batch = select_topk(current_batch, 
self.top_n_rollouts) rollout_batches.append(rollout_batch) full_batches.append(current_batch) diff --git a/nemo_aligner/models/nlp/gpt/megatron_gpt_rs_actor.py b/nemo_aligner/models/nlp/gpt/megatron_gpt_rs_actor.py index aaaa6541..d56e5c90 100644 --- a/nemo_aligner/models/nlp/gpt/megatron_gpt_rs_actor.py +++ b/nemo_aligner/models/nlp/gpt/megatron_gpt_rs_actor.py @@ -173,55 +173,6 @@ def finish_training(self): """no need to offload adam states here """ - # inference calls - def get_logprob_output_only_func(self, inference_only=True): - fwd_output_only_func = self.get_forward_output_only_func() - - def log_prob_output_only_func(dataloader_iter, model): - batch = next(dataloader_iter) - - output_tensor, _ = fwd_output_only_func(iter([batch,]), model) - - def id_func(output_tensor, non_loss_data=True): - logprobs = from_parallel_logits_to_logprobs( - vocab_parallel_logits=output_tensor, target=batch[0], inference_only=inference_only - ) - return logprobs - - return output_tensor, id_func - - return log_prob_output_only_func - - @torch.no_grad() - def get_inference_log_probs(self, response_tokens, forward_micro_batch_size): - set_sync_funcs(self, forward_only=True) - - mbs, seq_length = response_tokens.size() - num_microbatches = divide(mbs, forward_micro_batch_size) - - attention_mask, _, position_ids = self.get_ltor_masks_and_position_ids(response_tokens) - - batch_iter = get_iterator_k_split([response_tokens, attention_mask, position_ids], num_microbatches) - - fwd_bwd_function = get_forward_backward_func() - logprobs_list = fwd_bwd_function( - forward_step_func=self.get_logprob_output_only_func(inference_only=True), - data_iterator=batch_iter, - model=self.model, - num_microbatches=num_microbatches, - forward_only=True, - seq_length=seq_length, - micro_batch_size=forward_micro_batch_size, - collect_non_loss_data=True, - ) - - logprobs = torch.cat(logprobs_list) if len(logprobs_list) > 0 else None - - # Broadcast it from last PP stage to everything else. 
- logprobs = broadcast_2d_tensor_within_pp(logprobs) - - return logprobs - def prepare_for_inference(self): """normally we would configure the micro batch calculator here but the nemo generation already does the configuration""" @@ -263,11 +214,6 @@ def infer(self, inference_batch): f"`response_tokens` ({response_tokens.size(1)})" ) - # # TODO(geshen): get nemo generate to return the unaltered log probs - # log_probs = self.get_inference_log_probs( - # response_tokens, forward_micro_batch_size=self.forward_micro_batch_size - # ) - rollout_batch = { "response_tokens": response_tokens, "response_lengths": response_lengths, @@ -277,28 +223,6 @@ def infer(self, inference_batch): # return in GPU, trainer needs to move to cpu return rollout_batch - def get_init_policy_logprobs(self, rollout_batches): - init_log_probs = [] - if self.use_peft and self.init_policy_state_dict is None: - # when using adapters instead of full-tuning, the actor is init policy + adapters - with adapter_control(self): - # With adapters disabled (meaning using the init policy), calculate init_log_probs - for rollout_batch in rollout_batches: - init_log_prob = self.get_inference_log_probs( - rollout_batch["response_tokens"].cuda(), forward_micro_batch_size=self.forward_micro_batch_size - ) - init_log_probs.append(init_log_prob) - else: - with cpu_weight_swap(self, self.init_policy_state_dict, megatron_amp_O2=self.megatron_amp_O2): - for rollout_batch in rollout_batches: - init_log_prob = self.get_inference_log_probs( - rollout_batch["response_tokens"].cuda(), forward_micro_batch_size=self.forward_micro_batch_size - ) - init_log_probs.append(init_log_prob) - - # return in GPU, trainer needs to move to cpu - return init_log_probs - def finish_inference(self): # training will onload the adam states, no need to onload it here self._restore_activation_checkpointing_args()
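
For reference, the renamed top_n_rollouts argument controls how many of the num_rollout_per_prompt sampled responses per prompt survive the select_topk(current_batch, self.top_n_rollouts) call in nemo_aligner/algorithms/rs.py. The snippet below is a minimal, self-contained sketch of that selection step, not the NeMo-Aligner implementation of select_topk: the grouped tensor layout (consecutive rollouts per prompt), the helper name select_topk_sketch, and the toy batch are assumptions made for illustration; only the field names "response_tokens" and "rewards" and the two config knobs come from the patch above.

    # Hypothetical sketch of top-n rollout selection for rejection sampling.
    # Assumes the rollout batch stores `num_rollout_per_prompt` consecutive
    # samples per prompt; this is NOT the actual nemo_aligner `select_topk`.
    import torch


    def select_topk_sketch(batch: dict, top_n_rollouts: int, num_rollout_per_prompt: int) -> dict:
        """Keep the top-n rollouts per prompt, ranked by reward."""
        rewards = batch["rewards"]                                   # shape: [num_prompts * num_rollout_per_prompt]
        num_prompts = rewards.numel() // num_rollout_per_prompt

        # Rank rollouts within each prompt group and take the indices of the best n.
        grouped = rewards.view(num_prompts, num_rollout_per_prompt)  # [num_prompts, num_rollout_per_prompt]
        top_idx = grouped.topk(top_n_rollouts, dim=-1).indices       # [num_prompts, top_n_rollouts]

        # Convert per-group indices back to flat indices into the batch.
        offsets = torch.arange(num_prompts).unsqueeze(-1) * num_rollout_per_prompt
        flat_idx = (top_idx + offsets).flatten()

        # Index every per-sample field with the surviving rollouts.
        return {key: value[flat_idx] for key, value in batch.items()}


    if __name__ == "__main__":
        # Toy batch: 2 prompts x 4 rollouts each; keep the single best rollout per prompt.
        batch = {
            "response_tokens": torch.arange(8).unsqueeze(-1),        # stand-in token ids, one per rollout
            "rewards": torch.tensor([0.1, 0.9, 0.3, 0.2, 0.5, 0.4, 0.8, 0.7]),
        }
        selected = select_topk_sketch(batch, top_n_rollouts=1, num_rollout_per_prompt=4)
        print(selected["rewards"])  # tensor([0.9000, 0.8000])

Run at smaller scale than the documented settings (model.rs.num_rollout_per_prompt=8, model.rs.top_n_rollouts=1), the sketch keeps the single highest-reward rollout from each group, which appears to be the role the renamed parameter plays in the RS trainer.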