From 45411e0aa64315c447f3902577de55c249049333 Mon Sep 17 00:00:00 2001
From: Alexander Bukharin
Date: Tue, 10 Sep 2024 13:19:31 -0400
Subject: [PATCH] Update ppo utils file

---
 CHANGELOG.md                    |  1 -
 docs/user-guide/index.rst       |  2 +-
 docs/user-guide/rs.rst          | 10 +++++-----
 nemo_aligner/utils/ppo_utils.py | 12 ++++++++----
 4 files changed, 14 insertions(+), 11 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 603d93b6f..e1a985bf5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,7 +4,6 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

 ## [Next Version]
-- Add rejection sampling to algorithms.

 ### New features and optimizations
 - Added functionality for generating multiple responses for each prompt.
diff --git a/docs/user-guide/index.rst b/docs/user-guide/index.rst
index 32a9d9f83..c448845df 100644
--- a/docs/user-guide/index.rst
+++ b/docs/user-guide/index.rst
@@ -33,7 +33,7 @@
    DPO is a simpler alignment method compared to RLHF. DPO introduces a novel parameterization of the reward model in RLHF. This parameterization allows us to extract the corresponding optimal policy.

 :ref:`Model Alignment by Rejection Sampling (RS) `
-   RS is a simple online alignment algorithm. In RS, the policy model generates several responses. These responses are assigned a scores by the reward model, and the highest scoring responses are used for SFT.
+   RS is a simple online alignment algorithm. In RS, the policy model generates several responses. These responses are assigned a score by the reward model, and the highest scoring responses are used for SFT.

 :ref:`Fine-tuning Stable Diffusion with DRaFT+ `
    DRaFT+ is an algorithm for fine-tuning text-to-image generative diffusion models by directly backpropagating through a reward model which alleviates the mode collapse issues from DRaFT algorithm and improves diversity through regularization.
diff --git a/docs/user-guide/rs.rst b/docs/user-guide/rs.rst
index 9c9de3c2e..2dfa4a5cc 100644
--- a/docs/user-guide/rs.rst
+++ b/docs/user-guide/rs.rst
@@ -7,7 +7,7 @@ Model Alignment by Rejection Sampling

 In this tutorial, we will go through the process of aligning a NeMo framework model using rejection sampling. These can be models such as LLaMa2 or Mistral. Our scripts will work the same way.

-RLHF is usually preceded by a Supervised Fine-Tuning (SFT). We should first follow the :ref:`Prerequisite guide ` and the :ref:`SFT guide `. After obtaining the SFT model, we will also need to train a reward model as in :ref:`PPO guide `. We will use the rejection sampling algorithm on the `Anthropic-HH-RLHF `__ dataset.
+Rejection Sampling is usually preceded by a Supervised Fine-Tuning (SFT). We should first follow the :ref:`Prerequisite guide ` and the :ref:`SFT guide `. After obtaining the SFT model, we will also need to train a reward model as in :ref:`PPO guide `. We will use the rejection sampling algorithm on the `Anthropic-HH-RLHF `__ dataset.

 Rejection Sampling Training
 ############
@@ -99,10 +99,10 @@ The RS Actor training job contains the master controller that makes the HTTP cal

 The above launches the initial and actor server on 1 node with 8 GPUs

-Launching Both Servers for RLHF training
+Launching Both Servers for Rejection Sampling training
 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

-You can use slurm to launch the 2 jobs and get them to coordinate together in a full RLHF job via the following:
+You can use slurm to launch the 2 jobs and get them to coordinate together in a full Rejection Sampling job via the following:

 .. code-block:: bash

@@ -219,12 +219,12 @@ You can use slurm to launch the 2 jobs and get them to coordinate together in a

    wait

-The above script runs the reward model critic server on 1 node and the actor on 1 node.
+The above script runs the reward model server on 1 node and the actor on 1 node.
 It is important to launch all jobs with ``&`` after the srun command, to ensure they do not block each other.

 .. note::

-   Make sure to change the critic arg ``trainer.rs.inference_micro_batch_size`` such that ``trainer.rs.inference_micro_batch_size * DP size <= model.rs.rollout_micro_batch_size``.
+   Make sure to change the reward model arg ``trainer.rs.inference_micro_batch_size`` such that ``trainer.rs.inference_micro_batch_size * DP size <= model.rs.rollout_micro_batch_size``.

 RS Results
 %%%%%%%%%%%
diff --git a/nemo_aligner/utils/ppo_utils.py b/nemo_aligner/utils/ppo_utils.py
index d69962eae..4bada756e 100644
--- a/nemo_aligner/utils/ppo_utils.py
+++ b/nemo_aligner/utils/ppo_utils.py
@@ -18,6 +18,7 @@
 from nemo_aligner.utils.utils import masked_mean
 import operator

+
 def calculate_advantages_and_returns(values, rewards, discount_factor, gae_lambda, mask=None):
     """calculate the per token advantages and returns for the entire sequence

@@ -91,16 +92,19 @@ def create_mask(values, prompt_lengths, response_lengths):

 def select_topk(batch, num_select=1):
     """
-    Function to select the topk responses for each unique prompt in a batch
+    Function to select the topk responses for each unique prompt in a batch.
+    Please note that this function samples the same top response for each identical prompt.
+    Duplicate prompts in the same batch may cause unexpected behavior.
     """
     unique_prompts = torch.unique(batch["prompt_tokens"], dim=0)
     selected_idx = []

     for i in range(len(unique_prompts)):
-        prompt_idx = torch.arange(len(batch["prompt_tokens"]))[(batch["prompt_tokens"] == unique_prompts[i]).all(1)]
-        sorted_idx = zip(prompt_idx, batch["rewards"][(batch["prompt_tokens"] == unique_prompts[i]).all(1)])
+        is_matching_prompt = (batch["prompt_tokens"] == unique_prompts[i]).all(1)
+        prompt_idx = torch.arange(len(batch["prompt_tokens"]))[is_matching_prompt]
+        sorted_idx = zip(prompt_idx, batch["rewards"][is_matching_prompt])
         sorted_idx = sorted(sorted_idx, key=operator.itemgetter(1))
         selected_idx += [x[0].item() for x in sorted_idx[-1 * num_select :]]

     selected_batch = {k: batch[k][selected_idx] for k in batch.keys()}
-    return selected_batch
+    return selected_batch
\ No newline at end of file
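As a quick illustration of the batch-size note updated above, here is a minimal sketch of the constraint the launch arguments must satisfy. The helper ``check_rs_batch_sizes`` and the example numbers are hypothetical; only the keys ``trainer.rs.inference_micro_batch_size``, the data-parallel (DP) size, and ``model.rs.rollout_micro_batch_size`` come from the docs change.

.. code-block:: python

   # Hypothetical sanity check for the updated note:
   # trainer.rs.inference_micro_batch_size * DP size <= model.rs.rollout_micro_batch_size
   def check_rs_batch_sizes(inference_micro_batch_size: int, dp_size: int, rollout_micro_batch_size: int) -> None:
       total = inference_micro_batch_size * dp_size
       if total > rollout_micro_batch_size:
           raise ValueError(
               f"inference_micro_batch_size * DP size ({total}) must be "
               f"<= rollout_micro_batch_size ({rollout_micro_batch_size})"
           )

   # Illustrative numbers: 2 * 4 = 8 <= 8 passes; 4 * 4 = 16 > 8 would raise.
   check_rs_batch_sizes(inference_micro_batch_size=2, dp_size=4, rollout_micro_batch_size=8)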
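Since the ``select_topk`` change above is a behavior-preserving refactor, a small usage sketch may help reviewers check the intended semantics. It assumes NeMo-Aligner is installed so that ``nemo_aligner.utils.ppo_utils`` is importable; the toy tensors and the values noted in the comments are illustrative only.

.. code-block:: python

   import torch

   from nemo_aligner.utils.ppo_utils import select_topk

   # Toy rollout batch: two unique prompts, each with two sampled responses.
   batch = {
       "prompt_tokens": torch.tensor([[1, 2], [1, 2], [3, 4], [3, 4]]),
       "response_tokens": torch.tensor([[5, 6], [7, 8], [9, 10], [11, 12]]),
       "rewards": torch.tensor([0.1, 0.9, 0.7, 0.2]),
   }

   # Keep only the highest-reward response for each unique prompt.
   selected = select_topk(batch, num_select=1)

   print(selected["rewards"])          # the 0.9 and 0.7 responses survive
   print(selected["response_tokens"])  # rows [7, 8] and [9, 10]

As the new docstring warns, rows with identical ``prompt_tokens`` are grouped together, so the returned batch contains ``num_select`` top-reward responses per unique prompt.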