
Commit

Remove logprob computation functions
Alexander Bukharin committed Sep 21, 2024
1 parent 7a5037b commit b29a806
Showing 4 changed files with 7 additions and 83 deletions.
6 changes: 3 additions & 3 deletions docs/user-guide/rs.rst
@@ -14,7 +14,7 @@ Rejection Sampling Training

After you have fine-tuned a GPT model using Supervised Fine-Tuning (SFT), and trained a reward model as explained in the preceding section, you can start aligning the policy using rejection sampling.

-During rejection sampling training, we have two models interacting with each other, which Aligner runs in separate jobs::
+During rejection sampling training, we have two models interacting with each other, which Aligner runs in separate jobs:

#. The Policy Network: This is the model we are training, and it should start from an SFT model.
#. The Reward Model (RM): This model takes a prompt concatenated with a response as input, and outputs a single scalar value: the reward, which the rejection sampling algorithm will try to maximize.
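To make the interaction between these two jobs concrete, here is a minimal, self-contained sketch of the rejection sampling selection step they implement together. It is not the NeMo-Aligner implementation: generate and score are hypothetical stand-ins for the policy's generation call and the reward model server, and the parameter names mirror the model.rs.num_rollout_per_prompt and model.rs.top_n_rollouts options shown further down.

from typing import Callable, List, Tuple

def rejection_sample(
    prompts: List[str],
    generate: Callable[[str, int], List[str]],  # policy: prompt -> n candidate responses
    score: Callable[[str, str], float],         # reward model: (prompt, response) -> scalar reward
    num_rollout_per_prompt: int = 8,
    top_n_rollouts: int = 1,
) -> List[Tuple[str, str]]:
    """Keep only the highest-reward responses for each prompt."""
    selected = []
    for prompt in prompts:
        candidates = generate(prompt, num_rollout_per_prompt)
        ranked = sorted(candidates, key=lambda r: score(prompt, r), reverse=True)
        selected.extend((prompt, r) for r in ranked[:top_n_rollouts])
    return selected  # the kept (prompt, response) pairs feed the SFT-style policy update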
@@ -93,7 +93,7 @@ The RS Actor training job contains the master controller that makes the HTTP cal
remote_critic_rm.reward_model.ip=${host_critic} \
remote_critic_rm.reward_model.port=${CRITIC_PORT} \
model.rs.num_rollout_per_prompt=8 \
-model.rs.num_select=1
+model.rs.top_n_rollouts=1
The above launches the initial and actor server on 1 node with 8 GPUs
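The remote_critic_rm.reward_model.ip and .port options above point the actor job at the reward model server it queries over HTTP. The sketch below only illustrates the shape of that interaction; the /score endpoint, the payload format, and the get_rewards helper are invented for illustration and are not NeMo-Aligner's actual remote RM client.

import requests

def get_rewards(host: str, port: int, samples: list) -> list:
    """Hypothetical client: send (prompt + response) samples, receive one scalar reward each."""
    resp = requests.post(
        f"http://{host}:{port}/score",   # endpoint name is an assumption
        json={"samples": samples},       # payload shape is an assumption
        timeout=300,
    )
    resp.raise_for_status()
    return resp.json()["rewards"]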

@@ -210,7 +210,7 @@ You can use slurm to launch the 2 jobs and get them to coordinate together in a
remote_critic_rm.reward_model.ip=${host_critic} \
remote_critic_rm.reward_model.port=${CRITIC_PORT} \
model.rs.num_rollout_per_prompt=8 \
-model.rs.num_select=1
+model.rs.top_n_rollouts=1
EOF
srun --het-group=1 -o $PPO_OUTFILE -e $PPO_ERRFILE --container-image=${CONTAINER} $MOUNTS bash -c "${cmd_rs}" &
2 changes: 1 addition & 1 deletion examples/nlp/gpt/train_gpt_rs_actor.py
@@ -160,7 +160,7 @@ def main(cfg) -> None:
ckpt_callback=ckpt_callback,
run_timer=timer,
num_rollout_per_prompt=cfg.model.rs.num_rollouts_per_prompt,
-num_select=cfg.model.rs.top_n_rollouts,
+top_n_rollouts=cfg.model.rs.top_n_rollouts,
)

if custom_trainer_state_dict is not None:
6 changes: 3 additions & 3 deletions nemo_aligner/algorithms/rs.py
@@ -53,7 +53,7 @@ def __init__(
ckpt_callback,
run_timer,
num_rollout_per_prompt,
-num_select,
+top_n_rollouts,
rm,
):
self.cfg = cfg
@@ -65,7 +65,7 @@ def __init__(
self.logger = logger
self.ckpt_callback = ckpt_callback
self.num_rollout_per_prompt = num_rollout_per_prompt
-self.num_select = num_select
+self.top_n_rollouts = top_n_rollouts
self.rm = rm

# this timer checks if we should stop training
@@ -186,7 +186,7 @@ def _run_inference(self, dataloader_iter, num_microbatches, is_validation):
current_batch["rewards"] = torch.concatenate([current_batch["rewards"], rewards], dim=0)
else:
current_batch["rewards"] = rewards
-rollout_batch = select_topk(current_batch, self.num_select)
+rollout_batch = select_topk(current_batch, self.top_n_rollouts)

rollout_batches.append(rollout_batch)
full_batches.append(current_batch)
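In the hunk above, select_topk keeps only the top_n_rollouts highest-reward rollouts out of the num_rollout_per_prompt responses generated for each prompt. A minimal sketch of that selection, assuming the rollouts for a given prompt sit contiguously in the flat batch (the real helper in nemo_aligner may organize and slice the batch differently):

import torch

def select_topk_indices(rewards: torch.Tensor, num_rollout_per_prompt: int, top_n: int) -> torch.Tensor:
    """Return flat indices of the top-n rollouts (by reward) for each prompt."""
    grouped = rewards.view(-1, num_rollout_per_prompt)  # [num_prompts, num_rollouts_per_prompt]
    top_idx = grouped.topk(top_n, dim=-1).indices       # best rollouts within each prompt's group
    offsets = torch.arange(grouped.size(0), device=rewards.device).unsqueeze(-1) * num_rollout_per_prompt
    return (top_idx + offsets).flatten()                # indices into the flat rollout batch

# usage sketch: keep_idx = select_topk_indices(current_batch["rewards"], num_rollout_per_prompt=8, top_n=1)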
76 changes: 0 additions & 76 deletions nemo_aligner/models/nlp/gpt/megatron_gpt_rs_actor.py
@@ -173,55 +173,6 @@ def finish_training(self):
"""no need to offload adam states here
"""

    # inference calls
    def get_logprob_output_only_func(self, inference_only=True):
        fwd_output_only_func = self.get_forward_output_only_func()

        def log_prob_output_only_func(dataloader_iter, model):
            batch = next(dataloader_iter)

            output_tensor, _ = fwd_output_only_func(iter([batch,]), model)

            def id_func(output_tensor, non_loss_data=True):
                logprobs = from_parallel_logits_to_logprobs(
                    vocab_parallel_logits=output_tensor, target=batch[0], inference_only=inference_only
                )
                return logprobs

            return output_tensor, id_func

        return log_prob_output_only_func

    @torch.no_grad()
    def get_inference_log_probs(self, response_tokens, forward_micro_batch_size):
        set_sync_funcs(self, forward_only=True)

        mbs, seq_length = response_tokens.size()
        num_microbatches = divide(mbs, forward_micro_batch_size)

        attention_mask, _, position_ids = self.get_ltor_masks_and_position_ids(response_tokens)

        batch_iter = get_iterator_k_split([response_tokens, attention_mask, position_ids], num_microbatches)

        fwd_bwd_function = get_forward_backward_func()
        logprobs_list = fwd_bwd_function(
            forward_step_func=self.get_logprob_output_only_func(inference_only=True),
            data_iterator=batch_iter,
            model=self.model,
            num_microbatches=num_microbatches,
            forward_only=True,
            seq_length=seq_length,
            micro_batch_size=forward_micro_batch_size,
            collect_non_loss_data=True,
        )

        logprobs = torch.cat(logprobs_list) if len(logprobs_list) > 0 else None

        # Broadcast it from last PP stage to everything else.
        logprobs = broadcast_2d_tensor_within_pp(logprobs)

        return logprobs
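For reference, the helpers removed above computed per-token log-probabilities of the sampled responses. A single-GPU sketch of that computation (the removed Megatron code additionally handles vocab-parallel logits via from_parallel_logits_to_logprobs, the microbatch split, and the broadcast from the last pipeline stage):

import torch
import torch.nn.functional as F

def response_logprobs(logits: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    """logits: [batch, seq, vocab]; tokens: [batch, seq]. Returns log p(token_t | tokens_<t), shape [batch, seq - 1]."""
    log_probs = F.log_softmax(logits[:, :-1, :].float(), dim=-1)  # predictions for positions 1..seq-1
    target = tokens[:, 1:].unsqueeze(-1)                          # the tokens that were actually sampled
    return log_probs.gather(dim=-1, index=target).squeeze(-1)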

def prepare_for_inference(self):
"""normally we would configure the micro batch calculator here
but the nemo generation already does the configuration"""
@@ -263,11 +214,6 @@ def infer(self, inference_batch):
f"`response_tokens` ({response_tokens.size(1)})"
)

        # # TODO(geshen): get nemo generate to return the unaltered log probs
        # log_probs = self.get_inference_log_probs(
        #     response_tokens, forward_micro_batch_size=self.forward_micro_batch_size
        # )

rollout_batch = {
"response_tokens": response_tokens,
"response_lengths": response_lengths,
@@ -277,28 +223,6 @@ def infer(self, inference_batch):
# return in GPU, trainer needs to move to cpu
return rollout_batch

    def get_init_policy_logprobs(self, rollout_batches):
        init_log_probs = []
        if self.use_peft and self.init_policy_state_dict is None:
            # when using adapters instead of full-tuning, the actor is init policy + adapters
            with adapter_control(self):
                # With adapters disabled (meaning using the init policy), calculate init_log_probs
                for rollout_batch in rollout_batches:
                    init_log_prob = self.get_inference_log_probs(
                        rollout_batch["response_tokens"].cuda(), forward_micro_batch_size=self.forward_micro_batch_size
                    )
                    init_log_probs.append(init_log_prob)
        else:
            with cpu_weight_swap(self, self.init_policy_state_dict, megatron_amp_O2=self.megatron_amp_O2):
                for rollout_batch in rollout_batches:
                    init_log_prob = self.get_inference_log_probs(
                        rollout_batch["response_tokens"].cuda(), forward_micro_batch_size=self.forward_micro_batch_size
                    )
                    init_log_probs.append(init_log_prob)

        # return in GPU, trainer needs to move to cpu
        return init_log_probs
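The removed get_init_policy_logprobs scored rollouts under the frozen starting policy, either by disabling the PEFT adapters (adapter_control) or by temporarily swapping in the initial weights (cpu_weight_swap). A schematic sketch of the weight-swap path; swap_weights and logprob_fn are hypothetical stand-ins, not NeMo-Aligner's utilities:

from contextlib import contextmanager

@contextmanager
def swap_weights(model, init_state_dict):
    """Temporarily load the initial-policy weights, then restore the trained ones."""
    trained_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
    model.load_state_dict(init_state_dict)
    try:
        yield model
    finally:
        model.load_state_dict(trained_state)

def init_policy_logprobs(model, init_state_dict, rollout_batches, logprob_fn):
    """Score each rollout batch under the frozen starting policy."""
    with swap_weights(model, init_state_dict):
        return [logprob_fn(model, batch["response_tokens"]) for batch in rollout_batches]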

def finish_inference(self):
# training will onload the adam states, no need to onload it here
self._restore_activation_checkpointing_args()
