
Commit

Update ppo utils file
Alexander Bukharin committed Sep 10, 2024
1 parent cc137ef commit 45411e0
Showing 4 changed files with 14 additions and 11 deletions.
1 change: 0 additions & 1 deletion CHANGELOG.md
@@ -4,7 +4,6 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Next Version]
-- Add rejection sampling to algorithms.

### New features and optimizations
- Added functionality for generating multiple responses for each prompt.
Expand Down
2 changes: 1 addition & 1 deletion docs/user-guide/index.rst
@@ -33,7 +33,7 @@
DPO is a simpler alignment method compared to RLHF. DPO introduces a novel parameterization of the reward model in RLHF. This parameterization allows us to extract the corresponding optimal policy.

:ref:`Model Alignment by Rejection Sampling (RS) <model-aligner-rs>`
-RS is a simple online alignment algorithm. In RS, the policy model generates several responses. These responses are assigned a scores by the reward model, and the highest scoring responses are used for SFT.
+RS is a simple online alignment algorithm. In RS, the policy model generates several responses. These responses are assigned a score by the reward model, and the highest scoring responses are used for SFT.

:ref:`Fine-tuning Stable Diffusion with DRaFT+ <model-aligner-draftp>`
DRaFT+ is an algorithm for fine-tuning text-to-image generative diffusion models by directly backpropagating through a reward model which alleviates the mode collapse issues from DRaFT algorithm and improves diversity through regularization.
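
For context, here is a minimal sketch of the rejection sampling loop summarized in the RS entry above; ``policy.generate``, ``reward_model.score``, and ``policy.finetune`` are hypothetical stand-ins for illustration, not NeMo-Aligner APIs.

.. code-block:: python

    def rejection_sampling_round(policy, reward_model, prompts, samples_per_prompt=4):
        """One RS round: sample responses, score them, fine-tune on the best ones."""
        sft_examples = []
        for prompt in prompts:
            # The policy model generates several candidate responses per prompt.
            candidates = [policy.generate(prompt) for _ in range(samples_per_prompt)]
            # The reward model assigns a score to each candidate.
            scored = [(reward_model.score(prompt, response), response) for response in candidates]
            # Only the highest-scoring response is kept for supervised fine-tuning.
            _, best_response = max(scored, key=lambda pair: pair[0])
            sft_examples.append((prompt, best_response))
        policy.finetune(sft_examples)  # SFT step on the selected responses
        return sft_examples
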
10 changes: 5 additions & 5 deletions docs/user-guide/rs.rst
@@ -7,7 +7,7 @@ Model Alignment by Rejection Sampling

In this tutorial, we will go through the process of aligning a NeMo framework model using rejection sampling. These can be models such as LLaMa2 or Mistral. Our scripts will work the same way.

-RLHF is usually preceded by a Supervised Fine-Tuning (SFT). We should first follow the :ref:`Prerequisite guide <prerequisite>` and the :ref:`SFT guide <sft>`. After obtaining the SFT model, we will also need to train a reward model as in :ref:`PPO guide <ppo>`. We will use the rejection sampling algorithm on the `Anthropic-HH-RLHF <https://huggingface.co/datasets/Anthropic/hh-rlhf>`__ dataset.
+Rejection Sampling is usually preceded by a Supervised Fine-Tuning (SFT). We should first follow the :ref:`Prerequisite guide <prerequisite>` and the :ref:`SFT guide <sft>`. After obtaining the SFT model, we will also need to train a reward model as in :ref:`PPO guide <ppo>`. We will use the rejection sampling algorithm on the `Anthropic-HH-RLHF <https://huggingface.co/datasets/Anthropic/hh-rlhf>`__ dataset.

Rejection Sampling Training
############
@@ -99,10 +99,10 @@ The RS Actor training job contains the master controller that makes the HTTP cal
The above launches the initial and actor server on 1 node with 8 GPUs

-Launching Both Servers for RLHF training
+Launching Both Servers for Rejection Sampling training
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

-You can use slurm to launch the 2 jobs and get them to coordinate together in a full RLHF job via the following:
+You can use slurm to launch the 2 jobs and get them to coordinate together in a full Rejection Sampling job via the following:

.. code-block:: bash
@@ -219,12 +219,12 @@ You can use slurm to launch the 2 jobs and get them to coordinate together in a
wait
-The above script runs the reward model critic server on 1 node and the actor on 1 node.
+The above script runs the reward model server on 1 node and the actor on 1 node.
It is important to launch all jobs with ``&`` after the srun command, to ensure they do not block each other.
.. note::
-Make sure to change the critic arg ``trainer.rs.inference_micro_batch_size`` such that ``trainer.rs.inference_micro_batch_size * DP size <= model.rs.rollout_micro_batch_size``.
+Make sure to change the reward model arg ``trainer.rs.inference_micro_batch_size`` such that ``trainer.rs.inference_micro_batch_size * DP size <= model.rs.rollout_micro_batch_size``.
RS Results
%%%%%%%%%%%
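The ``inference_micro_batch_size`` note in the diff above amounts to a simple inequality; a quick sanity check is sketched below with placeholder values, not recommended or default settings.

.. code-block:: python

    # Constraint from the note above:
    #   trainer.rs.inference_micro_batch_size * DP size <= model.rs.rollout_micro_batch_size
    inference_micro_batch_size = 2  # trainer.rs.inference_micro_batch_size (placeholder value)
    data_parallel_size = 4          # reward model data-parallel (DP) size (placeholder value)
    rollout_micro_batch_size = 8    # model.rs.rollout_micro_batch_size (placeholder value)

    assert inference_micro_batch_size * data_parallel_size <= rollout_micro_batch_size, (
        "inference_micro_batch_size * DP size must not exceed rollout_micro_batch_size"
    )
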
12 changes: 8 additions & 4 deletions nemo_aligner/utils/ppo_utils.py
@@ -18,6 +18,7 @@
from nemo_aligner.utils.utils import masked_mean
import operator


def calculate_advantages_and_returns(values, rewards, discount_factor, gae_lambda, mask=None):
"""calculate the per token advantages and returns for the entire sequence
@@ -91,16 +92,19 @@ def create_mask(values, prompt_lengths, response_lengths):

def select_topk(batch, num_select=1):
"""
-    Function to select the topk responses for each unique prompt in a batch
+    Function to select the topk responses for each unique prompt in a batch.
+    Please note that this function samples the same top response for each identical prompt.
+    Duplicate prompts in the same batch may cause unexpected behavior.
"""
unique_prompts = torch.unique(batch["prompt_tokens"], dim=0)
selected_idx = []

for i in range(len(unique_prompts)):
-        prompt_idx = torch.arange(len(batch["prompt_tokens"]))[(batch["prompt_tokens"] == unique_prompts[i]).all(1)]
-        sorted_idx = zip(prompt_idx, batch["rewards"][(batch["prompt_tokens"] == unique_prompts[i]).all(1)])
+        is_matching_prompt = (batch["prompt_tokens"] == unique_prompts[i]).all(1)
+        prompt_idx = torch.arange(len(batch["prompt_tokens"]))[is_matching_prompt]
+        sorted_idx = zip(prompt_idx, batch["rewards"][is_matching_prompt])
sorted_idx = sorted(sorted_idx, key=operator.itemgetter(1))
selected_idx += [x[0].item() for x in sorted_idx[-1 * num_select :]]

selected_batch = {k: batch[k][selected_idx] for k in batch.keys()}
-    return selected_batch
+    return selected_batch
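
A minimal usage sketch of the updated ``select_topk``; the toy batch below is illustrative only, and a real rollout batch will typically carry additional keys.

.. code-block:: python

    import torch

    from nemo_aligner.utils.ppo_utils import select_topk

    # Two responses for each of two prompts, with one scalar reward per response.
    batch = {
        "prompt_tokens": torch.tensor([[1, 2], [1, 2], [3, 4], [3, 4]]),
        "response_tokens": torch.tensor([[5], [6], [7], [8]]),
        "rewards": torch.tensor([0.1, 0.9, 0.5, 0.2]),
    }

    # Keep the single highest-reward response per unique prompt.
    best = select_topk(batch, num_select=1)
    # best["rewards"] is tensor([0.9000, 0.5000]); rows 1 and 2 of the original batch are kept.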
