From cc137ef6b864d243ca657f44840c9042c9e0b8e5 Mon Sep 17 00:00:00 2001
From: Alexander Bukharin
Date: Tue, 10 Sep 2024 13:09:50 -0400
Subject: [PATCH] Fix RS descriptions

---
 examples/nlp/gpt/conf/gpt_rs_actor.yaml      | 28 +++++--------------
 examples/nlp/gpt/train_gpt_rs_actor.py       |  6 ++--
 nemo_aligner/algorithms/rs.py                | 22 ++++++---------
 .../models/nlp/gpt/megatron_gpt_rs_actor.py  |  4 +--
 .../models/nlp/gpt/reward_critic_clients.py  |  5 ++--
 nemo_aligner/utils/ppo_utils.py              |  4 +--
 6 files changed, 25 insertions(+), 44 deletions(-)

diff --git a/examples/nlp/gpt/conf/gpt_rs_actor.yaml b/examples/nlp/gpt/conf/gpt_rs_actor.yaml
index 74aa07ca7..60eb38a12 100644
--- a/examples/nlp/gpt/conf/gpt_rs_actor.yaml
+++ b/examples/nlp/gpt/conf/gpt_rs_actor.yaml
@@ -34,29 +34,17 @@ trainer:
   max_epochs: ${.rs.max_epochs}
   max_steps: ${.rs.max_steps}
 
-remote_critic_rm:
+remote_rm:
   # what to batch the inputs to
   # set to None if no batching when sending inference to the critic
   pad_to_length: ${model.encoder_seq_length}
 
-  # must match the same flag in the critic config
-  combine_rm_and_critic_server: True
-
-  # reward model server, specify if
-  # combine_rm_and_critic server is False
+  # reward model server
   reward_model:
     name: reward_model
     ip: localhost
     port: 5555
 
-  critic:
-    name:
-      train: critic_train
-      infer: critic_infer
-      save: critic_save
-    ip: localhost
-    port: 5556
-
 exp_manager:
   explicit_log_dir: /results
 
@@ -94,8 +82,8 @@ model:
   # memory usage
   forward_micro_batch_size: ${.rollout_micro_batch_size}
 
-  num_rollout_per_prompt: 1
-  num_select: 1
+  num_rollout_per_prompt: 1 # Number of completions to sample per prompt
+  num_select: 1 # Number of completions to select based on reward and train upon
 
   # val generation mbs
   val_rollout_micro_batch_size: ${.rollout_micro_batch_size}
@@ -108,14 +96,12 @@ model:
   sampling_params:
     use_greedy: False
     temperature: 1.0
-    top_k: 0 # Is this okay?
+    top_k: 0
     top_p: 1.0
     repetition_penalty: 1.0
     add_BOS: False
     all_probs: False
     compute_logprob: False
-    # will be used in NeMo version > 1.20.0
-    # keeping it for now
     end_strings: ["<|endoftext|>", ""]
 
   # length argument for autoregressive sampling
@@ -180,7 +166,7 @@ model:
     splits_string: null
     seq_length: ${model.encoder_seq_length}
    skip_warmup: True
-    num_workers: 2
+    num_workers: 0
     dataloader_type: single # cyclic
     reset_position_ids: False # Reset position ids after end-of-document token
     reset_attention_mask: False # Reset attention mask after end-of-document token
@@ -191,4 +177,4 @@ model:
     # define fields from the base model's config that should be ignored when merging with this config.
     overwrite_base_config:
       data:
-        data_prefix: True
+        data_prefix: True
\ No newline at end of file
diff --git a/examples/nlp/gpt/train_gpt_rs_actor.py b/examples/nlp/gpt/train_gpt_rs_actor.py
index d0bb3694e..791ffeca3 100644
--- a/examples/nlp/gpt/train_gpt_rs_actor.py
+++ b/examples/nlp/gpt/train_gpt_rs_actor.py
@@ -156,7 +156,7 @@ def main(cfg) -> None:
 
     logger.log_hyperparams(OmegaConf.to_container(cfg))
 
-    rm_critic = RemoteGPTRMClient(cfg.remote_critic_rm)
+    rm = RemoteGPTRMClient(cfg.remote_rm)
     timer = Timer(cfg.exp_manager.get("max_time_per_run"))
 
     rs_trainer = RSTrainer(
@@ -166,7 +166,7 @@ def main(cfg) -> None:
         scheduler=scheduler,
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
-        rm_critic=rm_critic,
+        rm=rm,
         logger=logger,
         ckpt_callback=ckpt_callback,
         run_timer=timer,
@@ -182,4 +182,4 @@ def main(cfg) -> None:
 
 
 if __name__ == "__main__":
-    main()
+    main()
\ No newline at end of file
diff --git a/nemo_aligner/algorithms/rs.py b/nemo_aligner/algorithms/rs.py
index a0da53e59..982877f55 100644
--- a/nemo_aligner/algorithms/rs.py
+++ b/nemo_aligner/algorithms/rs.py
@@ -55,7 +55,7 @@ def __init__(
         generation_iter,
         duplicate_prompts,
         num_select,
-        rm_critic,
+        rm,
     ):
         self.cfg = cfg
         self.model = model
@@ -68,7 +68,7 @@ def __init__(
         self.generation_iter = generation_iter
         self.duplicate_prompts = duplicate_prompts
         self.num_select = num_select
-        self.rm_critic = rm_critic
+        self.rm = rm
 
         # this timer checks if we should stop training
         self.run_timer = run_timer
@@ -98,7 +98,7 @@ def generate_rs_data(self, rollout_batches):
         """generate rs specific data for training
         """
         rs_rollout_data = defaultdict(list)
-        rs_rollout_metrics = defaultdict(lambda: 0)
+        rs_rollout_metrics = defaultdict(int)
         num_samples = 0
 
         def post_process_tensor(tensor):
@@ -158,14 +158,14 @@ def _run_inference(self, dataloader_iter, num_microbatches, is_validation):
                 inference_batch_duplicated = {
                     'text':torch.concatenate([inference_batch['text']] * self.duplicate_prompts, dim=0), #input text padded to prompt_llen + max_response length
                     'length':torch.concatenate([inference_batch['length']] * self.duplicate_prompts, dim=0),
-                    'attention_mask':inference_batch['attention_mask'],
+                    'attention_mask':inference_batch['attention_mask'], # Lower triangular mask, same for every sample in the batch
                     'loss_mask':torch.concatenate([inference_batch['loss_mask']] * self.duplicate_prompts, dim=0),
                     'position_ids':torch.concatenate([inference_batch['position_ids']] * self.duplicate_prompts, dim=0),
                 }
 
                 for _ in range(self.generation_iter):
                     if current_batch is None:
-                        rollout_batch = self.model.infer(inference_batch_duplicated) # Note that critic mbs has to be set correctly
+                        rollout_batch = self.model.infer(inference_batch_duplicated)
                         current_batch = rollout_batch
                         current_batch["prompt_tokens"] = inference_batch_duplicated["text"]
                     else:
@@ -178,7 +178,7 @@ def _run_inference(self, dataloader_iter, num_microbatches, is_validation):
                         current_batch["prompt_lengths"] = torch.concatenate([current_batch["prompt_lengths"], rollout_batch["prompt_lengths"]], dim=0)
                         current_batch["prompt_tokens"] = torch.concatenate([current_batch["prompt_tokens"], inference_batch_duplicated["text"]], dim=0)
 
-                    rewards = self.rm_critic.infer_rm_critic(rollout_batch).result().detach()
+                    rewards = self.rm.infer_rm(rollout_batch).result().detach()
                     if "rewards" in current_batch:
                         current_batch["rewards"] = torch.concatenate([current_batch["rewards"], rewards], dim=0)
                     else:
@@ -191,9 +191,9 @@ def _run_inference(self, dataloader_iter, num_microbatches, is_validation):
         else:
             for _, inference_batch in zip(range(num_microbatches), dataloader_iter):
-                rollout_batch = self.model.infer(inference_batch) # Here we meed to get the prompts as well
+                rollout_batch = self.model.infer(inference_batch)
 
-                rewards = self.rm_critic.infer_rm_critic(rollout_batch).result().detach()
+                rewards = self.rm.infer_rm(rollout_batch).result().detach()
 
                 rollout_batch["rewards"] = rewards
                 rollout_batches.append(rollout_batch)
 
@@ -467,12 +467,8 @@ def save(self, extra_candidates=None, is_train_end=False):
         monitor_candidates = {k: torch.tensor(v, dtype=torch.int32) for k, v in self.state_dict().items()}
         monitor_candidates.update(extra_candidates)
 
-        # future = self.rm_critic.save()
-
         self.ckpt_callback.custom_save(monitor_candidates=monitor_candidates, is_train_end=is_train_end)
 
-        # future.result()
-
         self.model.finish_training()
 
     def set_max_steps(self):
@@ -483,4 +479,4 @@ def set_max_steps(self):
 
     @property
     def epoch(self):
-        return self.step // self.num_steps_per_epoch
+        return self.step // self.num_steps_per_epoch
\ No newline at end of file
diff --git a/nemo_aligner/models/nlp/gpt/megatron_gpt_rs_actor.py b/nemo_aligner/models/nlp/gpt/megatron_gpt_rs_actor.py
index 6bae76d1a..2953da350 100644
--- a/nemo_aligner/models/nlp/gpt/megatron_gpt_rs_actor.py
+++ b/nemo_aligner/models/nlp/gpt/megatron_gpt_rs_actor.py
@@ -55,7 +55,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
         super().__init__(cfg, trainer=trainer)
 
         self.automatic_optimization = False
-        self.init_policy_state_dict = None
         self.distributed_adam_offload_manager = None
 
         # length parameters for generation
@@ -104,6 +103,7 @@ def fwd_output_and_loss_func(data_iterator, model):
 
             parallel_logits = model(batch["tokens"], batch["position_ids"], batch["attention_mask"], labels=None,)
 
+            # TODO: This loss depends on the mbs, which can lead to inconsistencies. See https://github.com/NVIDIA/NeMo/issues/8343.
             def loss_func(parallel_logits):
                 mask = batch["mask"]
                 tokens = batch["tokens"]
@@ -334,4 +334,4 @@ def get_ltor_masks_and_position_ids(self, tokens):
         )
         attention_mask = attention_mask.expand(tokens.size(0), -1, -1, -1)
 
-        return attention_mask, loss_mask, position_ids
+        return attention_mask, loss_mask, position_ids
\ No newline at end of file
diff --git a/nemo_aligner/models/nlp/gpt/reward_critic_clients.py b/nemo_aligner/models/nlp/gpt/reward_critic_clients.py
index 86917d8c9..c15a22b65 100644
--- a/nemo_aligner/models/nlp/gpt/reward_critic_clients.py
+++ b/nemo_aligner/models/nlp/gpt/reward_critic_clients.py
@@ -193,10 +193,9 @@ def __post_init__(self):
         self.communicator = HTTPCommunicator.create_http_communicator_from_dict(server_dict)
         self.communicator.print_server_dict()
 
-        self.combine_rm_and_critic_server = self.cfg.combine_rm_and_critic_server
         self.pad_to_length = self.cfg.pad_to_length
 
-    def infer_rm_critic(self, rollout_batch):
+    def infer_rm(self, rollout_batch):
         response_tokens = rollout_batch["response_tokens"].cpu()
         og_seq_length = response_tokens.size(-1)
 
@@ -217,4 +216,4 @@ def infer_rm_critic(self, rollout_batch):
             self.communicator.send_data_to_server, server_name=self.cfg.reward_model.name, data=send_data
         )
 
-        return RMFutureResult(rm_future)
+        return RMFutureResult(rm_future)
\ No newline at end of file
diff --git a/nemo_aligner/utils/ppo_utils.py b/nemo_aligner/utils/ppo_utils.py
index a38f7e127..d69962eae 100644
--- a/nemo_aligner/utils/ppo_utils.py
+++ b/nemo_aligner/utils/ppo_utils.py
@@ -16,7 +16,7 @@
 import torch
 
 from nemo_aligner.utils.utils import masked_mean
-
+import operator
 
 def calculate_advantages_and_returns(values, rewards, discount_factor, gae_lambda, mask=None):
     """calculate the per token advantages and returns for the entire sequence
@@ -99,7 +99,7 @@ def select_topk(batch, num_select=1):
     for i in range(len(unique_prompts)):
         prompt_idx = torch.arange(len(batch["prompt_tokens"]))[(batch["prompt_tokens"] == unique_prompts[i]).all(1)]
         sorted_idx = zip(prompt_idx, batch["rewards"][(batch["prompt_tokens"] == unique_prompts[i]).all(1)])
-        sorted_idx = sorted(sorted_idx, key=lambda x: x[1])
+        sorted_idx = sorted(sorted_idx, key=operator.itemgetter(1))
         selected_idx += [x[0].item() for x in sorted_idx[-1 * num_select :]]
 
     selected_batch = {k: batch[k][selected_idx] for k in batch.keys()}
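
For reviewers, a minimal standalone sketch (not part of the patch) of the selection step that the renamed config options document: sample num_rollout_per_prompt completions per prompt, score them with the remote reward model, and keep the num_select highest-reward completions. The selection logic mirrors select_topk from nemo_aligner/utils/ppo_utils.py as shown above; the toy batch, tensor shapes, and reward values are invented for illustration only.

import operator

import torch


def select_topk(batch, num_select=1):
    """Keep the num_select highest-reward completions for each unique prompt (mirrors ppo_utils.select_topk)."""
    unique_prompts = torch.unique(batch["prompt_tokens"], dim=0)

    selected_idx = []
    for i in range(len(unique_prompts)):
        is_this_prompt = (batch["prompt_tokens"] == unique_prompts[i]).all(1)
        prompt_idx = torch.arange(len(batch["prompt_tokens"]))[is_this_prompt]
        # Sort this prompt's completions by reward and keep the indices of the top num_select.
        sorted_idx = sorted(zip(prompt_idx, batch["rewards"][is_this_prompt]), key=operator.itemgetter(1))
        selected_idx += [x[0].item() for x in sorted_idx[-num_select:]]

    return {k: batch[k][selected_idx] for k in batch}


if __name__ == "__main__":
    # Toy batch: two prompts with num_rollout_per_prompt=2 completions each
    # (rows 0 and 2 share one prompt, rows 1 and 3 share the other).
    # The rewards are made-up stand-ins for the scores returned by RemoteGPTRMClient.infer_rm.
    batch = {
        "prompt_tokens": torch.tensor([[1, 1], [2, 2], [1, 1], [2, 2]]),
        "response_tokens": torch.tensor([[10], [20], [11], [21]]),
        "rewards": torch.tensor([0.1, 0.9, 0.7, 0.2]),
    }
    best = select_topk(batch, num_select=1)
    print(best["rewards"])  # tensor([0.7000, 0.9000]): the best completion per prompt

In RSTrainer, the rollouts gathered by _run_inference and scored via rm.infer_rm are expected to feed a selection step like this one before training.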