Commit

Fix RS descriptions

Alexander Bukharin committed Sep 10, 2024
1 parent d45d786 commit cc137ef
Showing 6 changed files with 25 additions and 44 deletions.
28 changes: 7 additions & 21 deletions examples/nlp/gpt/conf/gpt_rs_actor.yaml
@@ -34,29 +34,17 @@ trainer:
max_epochs: ${.rs.max_epochs}
max_steps: ${.rs.max_steps}

remote_critic_rm:
remote_rm:
# what to pad the inputs to
# set to None for no padding when sending inference to the reward model
pad_to_length: ${model.encoder_seq_length}

# must match the same flag in the critic config
combine_rm_and_critic_server: True

# reward model server, specify if
# combine_rm_and_critic server is False
# reward model server
reward_model:
name: reward_model
ip: localhost
port: 5555

critic:
name:
train: critic_train
infer: critic_infer
save: critic_save
ip: localhost
port: 5556


exp_manager:
explicit_log_dir: /results
@@ -94,8 +82,8 @@ model:
# memory usage
forward_micro_batch_size: ${.rollout_micro_batch_size}

num_rollout_per_prompt: 1
num_select: 1
num_rollout_per_prompt: 1 # Number of completions to sample per prompt
num_select: 1 # Number of completions to select based on reward and train upon

# val generation mbs
val_rollout_micro_batch_size: ${.rollout_micro_batch_size}
@@ -108,14 +96,12 @@ model:
sampling_params:
use_greedy: False
temperature: 1.0
top_k: 0 # Is this okay?
top_k: 0
top_p: 1.0
repetition_penalty: 1.0
add_BOS: False
all_probs: False
compute_logprob: False
# will be used in NeMo version > 1.20.0
# keeping it for now
end_strings: ["<|endoftext|>", "<extra_id_1>"]

# length argument for autoregressive sampling
@@ -180,7 +166,7 @@ model:
splits_string: null
seq_length: ${model.encoder_seq_length}
skip_warmup: True
num_workers: 2
num_workers: 0
dataloader_type: single # cyclic
reset_position_ids: False # Reset position ids after end-of-document token
reset_attention_mask: False # Reset attention mask after end-of-document token
@@ -191,4 +177,4 @@ model:
# define fields from the base model's config that should be ignored when merging with this config.
overwrite_base_config:
data:
data_prefix: True
data_prefix: True
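
The two newly documented options describe a best-of-N loop: sample num_rollout_per_prompt completions for each prompt, score them with the remote reward model, and keep the num_select highest-reward completions for training. A minimal, self-contained sketch of that selection step (illustrative only; select_best_completions and reward_fn are hypothetical names, not NeMo-Aligner code):

# Best-of-N selection sketch for a single prompt (hypothetical helper, not repo code).
def select_best_completions(completions, reward_fn, num_select=1):
    # Score every sampled completion, then keep the num_select highest-reward ones.
    ranked = sorted(completions, key=reward_fn, reverse=True)
    return ranked[:num_select]

# Example: 4 rollouts per prompt, train on the single best one (len stands in for a reward model).
best = select_best_completions(["ok", "a longer answer", "meh", "the most detailed answer"], reward_fn=len, num_select=1)
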
6 changes: 3 additions & 3 deletions examples/nlp/gpt/train_gpt_rs_actor.py
@@ -156,7 +156,7 @@ def main(cfg) -> None:

logger.log_hyperparams(OmegaConf.to_container(cfg))

rm_critic = RemoteGPTRMClient(cfg.remote_critic_rm)
rm = RemoteGPTRMClient(cfg.remote_rm)
timer = Timer(cfg.exp_manager.get("max_time_per_run"))

rs_trainer = RSTrainer(
@@ -166,7 +166,7 @@ def main(cfg) -> None:
scheduler=scheduler,
train_dataloader=train_dataloader,
val_dataloader=val_dataloader,
rm_critic=rm_critic,
rm=rm,
logger=logger,
ckpt_callback=ckpt_callback,
run_timer=timer,
@@ -182,4 +182,4 @@ def main(cfg) -> None:


if __name__ == "__main__":
main()
main()
22 changes: 9 additions & 13 deletions nemo_aligner/algorithms/rs.py
@@ -55,7 +55,7 @@ def __init__(
generation_iter,
duplicate_prompts,
num_select,
rm_critic,
rm,
):
self.cfg = cfg
self.model = model
@@ -68,7 +68,7 @@ def __init__(
self.generation_iter = generation_iter
self.duplicate_prompts = duplicate_prompts
self.num_select = num_select
self.rm_critic = rm_critic
self.rm = rm

# this timer checks if we should stop training
self.run_timer = run_timer
@@ -98,7 +98,7 @@ def generate_rs_data(self, rollout_batches):
"""generate rs specific data for training
"""
rs_rollout_data = defaultdict(list)
rs_rollout_metrics = defaultdict(lambda: 0)
rs_rollout_metrics = defaultdict(int)
num_samples = 0

def post_process_tensor(tensor):
@@ -158,14 +158,14 @@ def _run_inference(self, dataloader_iter, num_microbatches, is_validation):
inference_batch_duplicated = {
'text':torch.concatenate([inference_batch['text']] * self.duplicate_prompts, dim=0), # input text padded to prompt length + max response length
'length':torch.concatenate([inference_batch['length']] * self.duplicate_prompts, dim=0),
'attention_mask':inference_batch['attention_mask'],
'attention_mask':inference_batch['attention_mask'], # Lower triangular mask, same for every sample in the batch
'loss_mask':torch.concatenate([inference_batch['loss_mask']] * self.duplicate_prompts, dim=0),
'position_ids':torch.concatenate([inference_batch['position_ids']] * self.duplicate_prompts, dim=0),
}
for _ in range(self.generation_iter):

if current_batch is None:
rollout_batch = self.model.infer(inference_batch_duplicated) # Note that critic mbs has to be set correctly
rollout_batch = self.model.infer(inference_batch_duplicated)
current_batch = rollout_batch
current_batch["prompt_tokens"] = inference_batch_duplicated["text"]
else:
@@ -178,7 +178,7 @@ def _run_inference(self, dataloader_iter, num_microbatches, is_validation):
current_batch["prompt_lengths"] = torch.concatenate([current_batch["prompt_lengths"], rollout_batch["prompt_lengths"]], dim=0)
current_batch["prompt_tokens"] = torch.concatenate([current_batch["prompt_tokens"], inference_batch_duplicated["text"]], dim=0)

rewards = self.rm_critic.infer_rm_critic(rollout_batch).result().detach()
rewards = self.rm.infer_rm(rollout_batch).result().detach()
if "rewards" in current_batch:
current_batch["rewards"] = torch.concatenate([current_batch["rewards"], rewards], dim=0)
else:
@@ -191,9 +191,9 @@ def _run_inference(self, dataloader_iter, num_microbatches, is_validation):

else:
for _, inference_batch in zip(range(num_microbatches), dataloader_iter):
rollout_batch = self.model.infer(inference_batch) # Here we meed to get the prompts as well
rollout_batch = self.model.infer(inference_batch)

rewards = self.rm_critic.infer_rm_critic(rollout_batch).result().detach()
rewards = self.rm.infer_rm(rollout_batch).result().detach()
rollout_batch["rewards"] = rewards
rollout_batches.append(rollout_batch)

@@ -467,12 +467,8 @@ def save(self, extra_candidates=None, is_train_end=False):
monitor_candidates = {k: torch.tensor(v, dtype=torch.int32) for k, v in self.state_dict().items()}
monitor_candidates.update(extra_candidates)

# future = self.rm_critic.save()

self.ckpt_callback.custom_save(monitor_candidates=monitor_candidates, is_train_end=is_train_end)

# future.result()

self.model.finish_training()

def set_max_steps(self):
@@ -483,4 +479,4 @@ def set_max_steps(self):

@property
def epoch(self):
return self.step // self.num_steps_per_epoch
return self.step // self.num_steps_per_epoch
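
The defaultdict(lambda: 0) to defaultdict(int) change above is behavior-preserving: int() returns 0, so both factories yield zero-initialized metrics, but int is the more idiomatic default factory and, unlike a lambda, the resulting dict can be pickled. A quick check:

from collections import defaultdict

a = defaultdict(lambda: 0)
b = defaultdict(int)
assert a["missing"] == b["missing"] == 0  # both default missing keys to 0
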
4 changes: 2 additions & 2 deletions nemo_aligner/models/nlp/gpt/megatron_gpt_rs_actor.py
@@ -55,7 +55,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
super().__init__(cfg, trainer=trainer)
self.automatic_optimization = False

self.init_policy_state_dict = None
self.distributed_adam_offload_manager = None

# length parameters for generation
@@ -104,6 +103,7 @@ def fwd_output_and_loss_func(data_iterator, model):

parallel_logits = model(batch["tokens"], batch["position_ids"], batch["attention_mask"], labels=None,)

# TODO: This loss depends on the mbs, which can lead to inconsistencies. See https://github.com/NVIDIA/NeMo/issues/8343.
def loss_func(parallel_logits):
mask = batch["mask"]
tokens = batch["tokens"]
@@ -334,4 +334,4 @@ def get_ltor_masks_and_position_ids(self, tokens):
)
attention_mask = attention_mask.expand(tokens.size(0), -1, -1, -1)

return attention_mask, loss_mask, position_ids
return attention_mask, loss_mask, position_ids
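
The new TODO points at a known averaging subtlety: taking a masked mean per microbatch and then averaging across microbatches is not the same as one token-weighted mean over all tokens, so the loss shifts with the microbatch size whenever microbatches hold different numbers of unmasked response tokens. A small numeric illustration (standalone sketch, not the repo's loss code):

import torch

losses_mb1 = torch.tensor([1.0, 1.0, 1.0, 1.0])  # microbatch with 4 unmasked tokens
losses_mb2 = torch.tensor([3.0, 3.0])             # microbatch with 2 unmasked tokens

# Mean per microbatch, then mean over microbatches (mbs-dependent): (1.0 + 3.0) / 2 = 2.0
per_mb_avg = torch.stack([losses_mb1.mean(), losses_mb2.mean()]).mean()

# Single token-weighted mean over all 6 tokens: 10.0 / 6 ~= 1.67
global_avg = torch.cat([losses_mb1, losses_mb2]).mean()

print(per_mb_avg.item(), global_avg.item())  # the two averages disagree
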
5 changes: 2 additions & 3 deletions nemo_aligner/models/nlp/gpt/reward_critic_clients.py
@@ -193,10 +193,9 @@ def __post_init__(self):

self.communicator = HTTPCommunicator.create_http_communicator_from_dict(server_dict)
self.communicator.print_server_dict()
self.combine_rm_and_critic_server = self.cfg.combine_rm_and_critic_server
self.pad_to_length = self.cfg.pad_to_length

def infer_rm_critic(self, rollout_batch):
def infer_rm(self, rollout_batch):
response_tokens = rollout_batch["response_tokens"].cpu()
og_seq_length = response_tokens.size(-1)

@@ -217,4 +216,4 @@ def infer_rm_critic(self, rollout_batch):
self.communicator.send_data_to_server, server_name=self.cfg.reward_model.name, data=send_data
)

return RMFutureResult(rm_future)
return RMFutureResult(rm_future)
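
With the rename, callers interact with a plain reward-model future; a rough usage sketch mirroring the call sites in this commit (cfg.remote_rm and rollout_batch are assumed to be prepared as in the trainer code above):

# Sketch of the renamed call path; setup of cfg and rollout_batch is assumed.
rm = RemoteGPTRMClient(cfg.remote_rm)

# infer_rm() ships the padded response tokens to the reward model server and
# returns a future; .result() blocks until the rewards arrive.
rewards = rm.infer_rm(rollout_batch).result().detach()
rollout_batch["rewards"] = rewards
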
4 changes: 2 additions & 2 deletions nemo_aligner/utils/ppo_utils.py
@@ -16,7 +16,7 @@

import torch
from nemo_aligner.utils.utils import masked_mean

import operator

def calculate_advantages_and_returns(values, rewards, discount_factor, gae_lambda, mask=None):
"""calculate the per token advantages and returns for the entire sequence
@@ -99,7 +99,7 @@ def select_topk(batch, num_select=1):
for i in range(len(unique_prompts)):
prompt_idx = torch.arange(len(batch["prompt_tokens"]))[(batch["prompt_tokens"] == unique_prompts[i]).all(1)]
sorted_idx = zip(prompt_idx, batch["rewards"][(batch["prompt_tokens"] == unique_prompts[i]).all(1)])
sorted_idx = sorted(sorted_idx, key=lambda x: x[1])
sorted_idx = sorted(sorted_idx, key=operator.itemgetter(1))
selected_idx += [x[0].item() for x in sorted_idx[-1 * num_select :]]

selected_batch = {k: batch[k][selected_idx] for k in batch.keys()}
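
operator.itemgetter(1) is a drop-in replacement for lambda x: x[1] as the sort key here; it picks out the reward at position 1 of each (index, reward) pair. A quick equivalence check:

import operator

pairs = [(0, 0.2), (1, 0.9), (2, 0.5)]  # (prompt index, reward)
assert sorted(pairs, key=operator.itemgetter(1)) == sorted(pairs, key=lambda x: x[1])
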
