Fix computation of response length in the general case (#19)
The previous logic only worked for EOS / <extra_id_1>

Signed-off-by: Olivier Delalleau <507137+odelalleau@users.noreply.github.com>
odelalleau committed Jan 27, 2024
1 parent d06b23f commit 003b655
Showing 7 changed files with 109 additions and 89 deletions.
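At a glance, the fix replaces the post-hoc length computation (`calculate_dialogue_response_lengths`) with a text-generation strategy that records, during generation, the step at which each sample ends. Below is a minimal usage sketch distilled from the actor changes in this commit; `actor_model`, `prompt_tokens`, `prompt_lengths`, `length_params`, and `sampling_params` are illustrative placeholders, not definitions from the diff.

```python
from nemo_aligner.utils.text_generation_utils import TrackLengthGPTModelTextGenerationStrategy

# Sketch of the new flow in MegatronGPTActorModel.infer(): track per-sample
# response lengths while generating, instead of recomputing them from end
# tokens afterwards. All names below are placeholders.
strategy = TrackLengthGPTModelTextGenerationStrategy(
    model=actor_model,                       # the GPT actor model (placeholder)
    context_lengths=prompt_lengths,          # per-sample prompt lengths (torch.Tensor)
    max_length=length_params["max_length"],
)
actor_output = actor_model.generate(
    inputs=(prompt_tokens, prompt_lengths),
    length_params=length_params,
    sampling_params=sampling_params,         # may now use custom `end_strings`
    strategy=strategy,
)
# Lengths count the prompt plus the token(s) that ended generation, if any.
response_lengths = strategy.get_lengths()
```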
7 changes: 4 additions & 3 deletions CHANGELOG.md
@@ -6,9 +6,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
## [Next Version]

### New features and optimizations
- Added public-facing official Dockerfile for NeMo-Aligner
- Memory optimization in PPO that helps avoid OOM in the actor when sending training data to the critic
- SFT: added support for custom validation metrics based on model generations
- Added public-facing official Dockerfile for NeMo-Aligner.
- PPO: memory optimization to help avoid OOM in the actor when sending training data to the critic.
- PPO: it is now possible to use a custom end string in `sampling_params.end_strings` that is different from `<extra_id_1>`.
- SFT: added support for custom validation metrics based on model generations.

### Breaking changes

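The `end_strings` flexibility called out in the changelog above is exercised through the actor's sampling parameters. A hedged sketch follows, assuming the sampling parameters live under `model.ppo.sampling_params` as in `examples/nlp/gpt/conf/gpt_ppo_actor.yaml` (verify the key path against your config version; the end string value is made up).

```python
from omegaconf import OmegaConf

# Hypothetical override: stop generation on a custom end string instead of the
# default "<extra_id_1>". The key path is assumed from
# examples/nlp/gpt/conf/gpt_ppo_actor.yaml; check it against your config.
cfg = OmegaConf.load("examples/nlp/gpt/conf/gpt_ppo_actor.yaml")
cfg.model.ppo.sampling_params.end_strings = ["<|im_end|>"]  # made-up end string
print(OmegaConf.to_yaml(cfg.model.ppo.sampling_params))
```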
12 changes: 6 additions & 6 deletions examples/nlp/gpt/conf/gpt_ppo_actor.yaml
@@ -10,7 +10,7 @@ trainer:

ppo:
max_steps: -1 # max PPO steps (-1 to go through the whole train set)
val_check_interval: 10
val_check_interval: 10
save_interval: ${.val_check_interval}
gradient_clip_val: 1.0

@@ -49,15 +49,15 @@ remote_critic_rm:
# must match the same flag in the critic config
combine_rm_and_critic_server: True

# reward model server, specify if
# reward model server, specify if
# combine_rm_and_critic server is False
reward_model:
name: reward_model
ip: localhost
port: 5555

critic:
name:
name:
train: critic_train
infer: critic_infer
save: critic_save
@@ -146,15 +146,15 @@ model:

# miscellaneous
seed: 1234

optim:
name: distributed_fused_adam
bucket_cap_mb: 200
overlap_grad_sync: False
contiguous_grad_buffer: True
lr: 9e-7
weight_decay: 0.1
betas:
weight_decay: 0.1
betas:
- 0.9
- 0.98
sched:
61 changes: 28 additions & 33 deletions nemo_aligner/models/nlp/gpt/megatron_gpt_ppo_actor.py
@@ -32,24 +32,19 @@
from nemo.collections.nlp.parts.utils_funcs import get_last_rank
from nemo_aligner.models.alignable_interface import AlignableGenerativeInterface
from nemo_aligner.utils.distributed import (
broadcast_2d_tensor,
broadcast_2d_tensor_within_pp,
calculate_distributed_entropy,
from_parallel_logits_to_logprobs,
)
from nemo_aligner.utils.text_generation_utils import TrackLengthGPTModelTextGenerationStrategy
from nemo_aligner.utils.train_utils import (
grad_reductions,
prepare_for_training_step,
set_eval,
set_sync_funcs,
set_train,
)
from nemo_aligner.utils.utils import (
calculate_dialogue_response_lengths,
configure_batch_sizes,
cpu_weight_swap,
masked_mean,
offload_distributed_adam,
)
from nemo_aligner.utils.utils import configure_batch_sizes, cpu_weight_swap, masked_mean, offload_distributed_adam


class MegatronGPTActorModel(MegatronGPTModel, AlignableGenerativeInterface):
@@ -65,15 +60,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
# sampling parameters for generation
self._sampling_params = OmegaConf.to_container(self.cfg.ppo.sampling_params, resolve=True)

# Safety check until https://github.com/NVIDIA/NeMo-Aligner/pull/19 is merged.
valid_end_strings = ["<|endoftext|>", "<extra_id_1>"]
for end_string in self._sampling_params["end_strings"]:
if end_string not in valid_end_strings:
raise NotImplementedError(
"Currently only '<|endoftext|>' and '<extra_id_1>' are allowed in `sampling_params.end_strings`, "
f"but found '{end_string}'"
)

self.to_offload_adam_states = self.cfg.ppo.offload_adam_states
self.entropy_bonus = self.cfg.ppo.entropy_bonus
self.ratio_eps = self.cfg.ppo.ratio_eps
@@ -256,13 +242,9 @@ def get_inference_log_probs(self, response_tokens, forward_micro_batch_size):
)

logprobs = torch.cat(logprobs_list) if len(logprobs_list) > 0 else None
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
# broadcast it from last PP stage to everything else
logprobs = broadcast_2d_tensor(
logprobs,
parallel_state.get_pipeline_model_parallel_last_rank(),
parallel_state.get_pipeline_model_parallel_group(),
)

# Broadcast it from last PP stage to everything else.
logprobs = broadcast_2d_tensor_within_pp(logprobs)

return logprobs

@@ -280,19 +262,32 @@ def infer(self, inference_batch):
prompt_lengths = inference_batch["length"].cuda(non_blocking=True)
inputs = (prompt_tokens, prompt_lengths)

strategy = TrackLengthGPTModelTextGenerationStrategy(
model=self, context_lengths=prompt_lengths, max_length=self._length_params["max_length"]
)
actor_output = self.generate(
inputs=inputs, length_params=self._length_params, sampling_params=self._sampling_params
inputs=inputs, length_params=self._length_params, sampling_params=self._sampling_params, strategy=strategy
)

response_lengths = strategy.get_lengths()
max_response_length = response_lengths.max().item()

response_tokens = torch.cuda.LongTensor(actor_output["token_ids"])
response_lengths = calculate_dialogue_response_lengths(
tokens=response_tokens,
prompt_lengths=prompt_lengths,
tokenizer=self.tokenizer,
end_strings=self._sampling_params["end_strings"],
max_generation_length=self._length_params["max_length"],
max_sequence_length=self.cfg.encoder_seq_length,
)

# Sanity check to validate response length.
if max_response_length != response_tokens.size(1):
# This may actually happen because NeMo does not always stop generation after `max_length` in batch mode
# => `response_tokens` may contain up to `max_length + max_context_length` tokens.
# TODO once NeMo fixes this issue we should be able to always raise an exception when the check above fails,
# and remove the `if` below.
if (
max_response_length >= response_tokens.size(1)
or response_tokens.size(1) != prompt_lengths.max().item() + self._length_params["max_length"]
):
raise AssertionError(
f"max response length ({max_response_length}) does not match the size of "
f"`response_tokens` ({response_tokens.size(1)})"
)

# TODO(geshen): get nemo generate to return the unaltered log probs
log_probs = self.get_inference_log_probs(
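The sanity check in `infer()` above tolerates a known NeMo behaviour: in batch mode, the token buffer returned by `generate()` can be wider than the longest tracked response. A toy, self-contained illustration of the accepted case (all numbers invented):

```python
# Toy numbers mirroring the accepted branch of the sanity check (all invented).
max_prompt_len = 100       # prompt_lengths.max().item()
max_gen_len = 50           # self._length_params["max_length"]
max_response_length = 130  # strategy.get_lengths().max().item(); e.g. the longest
                           # sample stopped 30 tokens into generation: 100 + 30
response_width = 150       # response_tokens.size(1); NeMo may pad the buffer out
                           # to max_prompt_len + max_gen_len even when every
                           # sample stopped earlier

# The check raises unless the widths match exactly, or the buffer is exactly
# max_prompt_len + max_gen_len wide and the tracked maximum is smaller.
if max_response_length != response_width:
    assert max_response_length < response_width
    assert response_width == max_prompt_len + max_gen_len
```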
9 changes: 2 additions & 7 deletions nemo_aligner/models/nlp/gpt/megatron_gpt_reward_model.py
Expand Up @@ -35,7 +35,7 @@
from nemo.utils import AppState, logging
from nemo_aligner.models.alignable_interface import Inferrable, SupervisedInterface
from nemo_aligner.models.nlp.gpt.gpt_reward_model import GPTRewardModel
from nemo_aligner.utils.distributed import broadcast_2d_tensor, gather_tensor
from nemo_aligner.utils.distributed import broadcast_2d_tensor, broadcast_2d_tensor_within_pp, gather_tensor
from nemo_aligner.utils.text_generation_utils import tokenize_batch
from nemo_aligner.utils.train_utils import (
finish_validation_step,
@@ -415,12 +415,7 @@ def infer(
if self.enable_standardization:
rewards = (rewards - self.rew_mean) / self.rew_std

if parallel_state.get_pipeline_model_parallel_world_size() > 1:
rewards = broadcast_2d_tensor(
rewards,
parallel_state.get_pipeline_model_parallel_last_rank(),
parallel_state.get_pipeline_model_parallel_group(),
)
rewards = broadcast_2d_tensor_within_pp(rewards)

rewards_list = gather_tensor(
rewards, dst=parallel_state.get_data_parallel_src_rank(), group=parallel_state.get_data_parallel_group()
16 changes: 14 additions & 2 deletions nemo_aligner/utils/distributed.py
@@ -69,13 +69,25 @@ def broadcast_2d_tensor(tensor, src, group, dtype=torch.float32):
return tensor


def broadcast_2d_tensor_within_mp(tensor):
def broadcast_2d_tensor_within_mp(tensor, dtype=torch.float32):
"""helper function to broadcast within the model parallel group
"""
group = parallel_state.get_model_parallel_group()

if torch.distributed.get_world_size(group) > 1:
return broadcast_2d_tensor(tensor, get_model_parallel_src_rank(), group)
return broadcast_2d_tensor(tensor, get_model_parallel_src_rank(), group, dtype=dtype)

return tensor


def broadcast_2d_tensor_within_pp(tensor, dtype=torch.float32):
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
return broadcast_2d_tensor(
tensor,
parallel_state.get_pipeline_model_parallel_last_rank(),
parallel_state.get_pipeline_model_parallel_group(),
dtype=dtype,
)

return tensor

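The new helper captures a pattern used twice in this commit: a 2-D result that only exists on the last pipeline stage is broadcast to the rest of the pipeline group, and is returned unchanged when pipeline parallelism is off. A hedged usage sketch, assuming Megatron parallel state (and CUDA) has already been initialized:

```python
import torch
from megatron.core import parallel_state
from nemo_aligner.utils.distributed import broadcast_2d_tensor_within_pp

# Only the last pipeline stage computes the value; other stages pass None and
# receive the broadcast copy (mirrors get_lengths() and infer() in this commit).
values = None
if parallel_state.is_pipeline_last_stage():
    values = torch.arange(8, device="cuda", dtype=torch.int64).view(-1, 1)

values = broadcast_2d_tensor_within_pp(values, dtype=torch.int64)
```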
55 changes: 55 additions & 0 deletions nemo_aligner/utils/text_generation_utils.py
@@ -14,10 +14,65 @@

"""Utilities for generating text."""

from typing import Any, List

import torch

from megatron.core import parallel_state
from nemo.collections.nlp.modules.common.text_generation_strategy import GPTModelTextGenerationStrategy
from nemo.utils import logging

from nemo_aligner.utils.distributed import broadcast_2d_tensor_within_pp


class TrackLengthGPTModelTextGenerationStrategy(GPTModelTextGenerationStrategy):
"""
Text generation strategy that tracks the length of the generated text.
TODO This is a temporary workaround until NeMo's `generate()` function returns this information.
"""

def __init__(self, model: Any, context_lengths: torch.Tensor, max_length: int):
super().__init__(model)
self._context_lengths = context_lengths
self._max_length = max_length
self._end_idx = torch.full_like(context_lengths, fill_value=-1)

def end_of_generation_condition(
self, tokens: torch.Tensor, prev: torch.Tensor, eod_id: int, end_strings: List[str]
) -> torch.Tensor:
is_end = super().end_of_generation_condition(tokens=tokens, prev=prev, eod_id=eod_id, end_strings=end_strings)
assert len(is_end) == len(tokens)
if len(tokens) != len(self._context_lengths):
raise RuntimeError(
"Batch size mismatch: the `context_lengths` tensor provided in the constructor has batch size "
f"{len(self._context_lengths)}, while the generated tokens have batch size {len(tokens)}"
)
context_length = tokens.size(1) - 1 # the input tokens come from `tokens[:, : context_length + 1]`
started = self._context_lengths <= context_length
# The generation ends right now when three conditions hold:
# - it has started
# - the end generation is triggered now
# - it did *not* end before
self._end_idx = torch.where(started & is_end & (self._end_idx < 0), context_length, self._end_idx)
return is_end

def get_lengths(self) -> torch.Tensor:
"""
Return the total lengths of the generated sequences, in # of tokens.
The total length of a generated sequence counts both:
* the context tokens (i.e., the input prompt)
* the token(s) that ended generation, if any (e.g. the `EOS` token or the token(s) corresponding to
an element of `sampling_params.end_strings`)
"""
lengths = None
if parallel_state.is_pipeline_last_stage(): # only the last stage actually has access to lengths
lengths = torch.where(self._end_idx >= 0, self._end_idx + 1, self._context_lengths + self._max_length)
lengths = lengths.to(torch.int64).view((-1, 1))
lengths = broadcast_2d_tensor_within_pp(lengths, dtype=torch.int64)
return lengths.flatten()


def pad_batch(batch, pad_id):
"""batch each element of the batch to be the size of the longest sequence
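To make the length bookkeeping in `TrackLengthGPTModelTextGenerationStrategy` concrete, here is a toy, single-process walk-through that reimplements its two `torch.where` updates directly (no model or parallel state required; the step numbers and end flags are invented):

```python
import torch

# Two samples: prompts of 3 and 5 tokens, max_length = 4.
context_lengths = torch.tensor([3, 5])
max_length = 4
end_idx = torch.full_like(context_lengths, -1)

# Simulated per-step "end of generation" flags as returned by the base strategy.
# Step i corresponds to context_length = i (tokens[:, : i + 1] were fed in).
# Sample 0 triggers its end string at context_length 5; sample 1 never ends.
is_end_per_step = {5: torch.tensor([True, False])}

for context_length in range(3, 3 + max_length + 5):  # toy generation loop
    is_end = is_end_per_step.get(context_length, torch.tensor([False, False]))
    started = context_lengths <= context_length
    # Record the step at which each sample first ends (same update as the class).
    end_idx = torch.where(started & is_end & (end_idx < 0), context_length, end_idx)

# Same formula as get_lengths(): ended samples -> end_idx + 1 (includes the end
# token); samples that never ended -> prompt length + max_length.
lengths = torch.where(end_idx >= 0, end_idx + 1, context_lengths + max_length)
print(lengths)  # tensor([6, 9]) -> sample 0: 5 + 1, sample 1: 5 + 4
```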
38 changes: 0 additions & 38 deletions nemo_aligner/utils/utils.py
@@ -184,44 +184,6 @@ def calculate_response_lengths(tokens, eos_id):
return (tokens != eos_id).sum(-1)


def calculate_dialogue_response_lengths(
tokens, prompt_lengths, tokenizer, end_strings, max_generation_length, max_sequence_length
):
# for EOS
eos_length = calculate_response_lengths(tokens, tokenizer.eos_id)

if "<extra_id_1>" in end_strings:
# for the extra_id_1
extra_id_1_idx = tokenizer.text_to_ids("<extra_id_1>")[-1]
mask = tokens == extra_id_1_idx

# take the last extra id token index(assumes we are not padding with extra_id_1)
length_with_extra_id_1 = torch.argmax(
mask * torch.arange(tokens.size(-1), device=torch.cuda.current_device()), dim=-1
)

# if it terminated on the extra token id, then it must have been generated by the model, otherwise it couldn't have
length_with_extra_id_1 = torch.where(
length_with_extra_id_1 >= prompt_lengths, length_with_extra_id_1, torch.iinfo(torch.int32).max
)

# either terminated using eos id or extra id 1
lengths = torch.minimum(eos_length, length_with_extra_id_1)
else:
lengths = eos_length

# we also want the model to learn EOS or extra id 1
lengths = lengths + 1
# Ensure we never go over `length_params.max_length`. Note that this means the response may not necessarily
# end with EOS / extra_id_1 (we should not enforce it as PPO training requires the real generated token).
max_lengths = prompt_lengths + max_generation_length
lengths = torch.minimum(lengths, max_lengths)

# Prompts' max size and `max_length` should be such that we never exceed the encoder input size.
assert (lengths <= max_sequence_length).all()
return lengths


def configure_batch_sizes(mbs, gbs, dp=1):
app_state = AppState()
_reconfigure_microbatch_calculator(
