Rejection sampling clean #218
base: main
Conversation
Force-pushed from ed54e24 to 72bdafb
Olivier's Review
Force-pushed from 45411e0 to 3bc3bfd
docs/user-guide/rs.rst (Outdated)
RESULTS_DIR="critic_results_dir"
export PYTHONPATH="${NEMO_RLHF_DIR}:${PYTHONPATH}" \
Should this say
export PYTHONPATH="${GPFS}:${PYTHONPATH}" \
instead of
export PYTHONPATH="${NEMO_RLHF_DIR}:${PYTHONPATH}" \
?
Yes it should.
Once all comments are addressed please make sure to run the tutorial scripts to check that they work.
docs/user-guide/rs.rst (Outdated)
--config-path=${CONF_DIR} \
--config-name=${CONFIG_NAME} \
We should drop these so the example works OOTB
Done.
Force-pushed from 6f7c825 to bae350e
'''

if num_rollout_samples % rollout_micro_batch_size != 0:
for this function generally -> can you use https://github.com/NVIDIA/Megatron-LM/blob/72008a0460000360ae4542b5411f25175d899b2e/megatron/core/utils.py#L34?
Changed this function to use divide.
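For reference, the linked Megatron-LM helper does roughly the following (a paraphrased sketch of the divide utility in megatron/core/utils.py, not the verbatim upstream code):

```python
def divide(numerator: int, denominator: int) -> int:
    """Paraphrase of Megatron-LM's divide utility: assert the split is
    exact, then return the integer quotient."""
    assert numerator % denominator == 0, f"{numerator} is not divisible by {denominator}"
    return numerator // denominator
```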
I believe this function is still needed, as there are two cases to consider when calculating the mbs.
hmm, sorry maybe i misunderstand -- why not just stack everything like you're doing in RStrainer and then try our best to cut the batch down to rollout mbs? if it's not divisible cleanly we can do a min
I removed compute_mbs and stacked everything in RStrainer as mentioned.
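For context, a minimal sketch of the "stack everything, then slice to the rollout micro-batch size" approach discussed above (names like split_into_micro_batches are illustrative, not the identifiers used in the PR; the final chunk simply comes out smaller when the total is not cleanly divisible):

```python
import torch

def split_into_micro_batches(stacked_batch: torch.Tensor, rollout_micro_batch_size: int):
    """Slice a stacked rollout batch into chunks of at most
    `rollout_micro_batch_size` samples along dim 0; the last chunk may be
    smaller when the batch size is not evenly divisible."""
    total = stacked_batch.shape[0]
    micro_batches = []
    for start in range(0, total, rollout_micro_batch_size):
        end = min(start + rollout_micro_batch_size, total)
        micro_batches.append(stacked_batch[start:end])
    return micro_batches
```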
Force-pushed from ef0a70a to 04b421d
Force-pushed from 04b421d to ee9ba1b
Did a pass through previous comments + some new ones
def load_state_dict(self, state_dict):
    self.step = state_dict["step"]
    self.consumed_samples = state_dict["consumed_samples"]
    self.rs_optimization_step = state_dict["ppo_optimization_step"]  # Due to the way we save checkpoints
Any update on this?
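For context, the comment in the snippet above suggests the save side reuses the PPO key name; a minimal sketch of what that implies (an assumption about the checkpoint format, not code from this PR):

```python
def state_dict(self):
    # Assumed counterpart to load_state_dict above: the shared checkpoint
    # format keeps the PPO-era key, so the RS step count is stored under
    # "ppo_optimization_step".
    return {
        "step": self.step,
        "consumed_samples": self.consumed_samples,
        "ppo_optimization_step": self.rs_optimization_step,
    }
```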
nemo_aligner/algorithms/rs.py (Outdated)
# Need to pad response tokens before concatenating: response tokens contain prompts concatenated with responses.
current_batch["response_tokens"], rollout_batch["response_tokens"] = pad_batch(
    current_batch["response_tokens"], rollout_batch["response_tokens"], self.model.tokenizer.eos_id
)

current_batch["response_tokens"] = torch.concatenate(
    [current_batch["response_tokens"], rollout_batch["response_tokens"]], dim=0
)
Any thoughts on this? (not a huge deal but seems a bit cleaner unless I'm missing something)
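For reference, a pad-then-concatenate helper like the pad_batch call in the snippet above could look roughly as follows (the signature is inferred from the call site; this is a sketch, not the PR's actual implementation):

```python
import torch
import torch.nn.functional as F

def pad_batch(tokens_a: torch.Tensor, tokens_b: torch.Tensor, pad_id: int):
    """Right-pad two [batch, seq_len] token tensors with `pad_id` so both
    share the longer sequence length, making them safe to concatenate
    along the batch dimension."""
    max_len = max(tokens_a.shape[1], tokens_b.shape[1])
    tokens_a = F.pad(tokens_a, (0, max_len - tokens_a.shape[1]), value=pad_id)
    tokens_b = F.pad(tokens_b, (0, max_len - tokens_b.shape[1]), value=pad_id)
    return tokens_a, tokens_b
```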
docs/user-guide/rs.rst (Outdated)
RESULTS_DIR="critic_results_dir"
export PYTHONPATH="${NEMO_RLHF_DIR}:${PYTHONPATH}" \
Once all comments are addressed please make sure to run the tutorial scripts to check that they work.
logger=logger,
ckpt_callback=ckpt_callback,
run_timer=timer,
num_rollout_per_prompt=cfg.model.rs.num_rollout_per_prompt,
Might just be because it's still WIP on your side, but just to be sure it's not overlooked: you renamed the config options (num_rollout_per_prompt and top_n_rollouts) as per my suggestion, but haven't updated the code + doc accordingly yet.
Yes, it was WIP; the code and doc are updated now.
What does this PR do?
Adds the rejection sampling algorithm: for each prompt the actor generates several candidate responses, a remote reward model scores them, and only the top-scoring responses are kept for the training step.
Changelog
Usage
read -r -d '' cmd_ppo <<EOF
wandb login ${WANDB_API_KEY} \
&& cd ${NEMO_RLHF_DIR} \
&& export PYTHONPATH="${NEMO_RLHF_DIR}:${PYTHONPATH}" \
&& export HYDRA_FULL_ERROR=1 \
&& export CUDA_LAUNCH_BLOCKING=1 \
&& export PYTRITON_HOME=/pytriton_cache \
&& export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512 \
&& python -u examples/nlp/gpt/train_gpt_rs_actor.py \
--config-path=${CONF_DIR} \
--config-name=${CONFIG_NAME} \
"model.data.data_prefix={train: [${TRAIN_DATA_PATH}], validation: [${VALID_DATA_PATH}], test: [${VALID_DATA_PATH}]}" \
pretrained_checkpoint.restore_from_path="${ACTOR_NEMO_FILE}" \
exp_manager.checkpoint_callback_params.save_top_k=1 \
exp_manager.explicit_log_dir="${ACTOR_LOG_DIR}" \
exp_manager.create_wandb_logger=True \
exp_manager.wandb_logger_kwargs.name="${ACTOR_NAME}" \
exp_manager.wandb_logger_kwargs.project=${WANDB_PROJECT} \
++exp_manager.max_time_per_run="00:03:30:00" \
trainer.rs.max_epochs=1 \
trainer.rs.max_steps=313 \
trainer.rs.val_check_interval=4 \
trainer.num_nodes=8 \
trainer.devices=8 \
++model.tensor_model_parallel_size=4 \
model.global_batch_size=${ACTOR_GBS} \
model.micro_batch_size=1 \
model.optim.lr="\${multiply:${ACTOR_LR},1.001}" \
model.optim.sched.warmup_steps=0 \
model.optim.sched.constant_steps=312 \
model.optim.sched.min_lr=${ACTOR_LR} \
model.optim.weight_decay=0.01 \
model.rs.num_rollout_samples=${NUM_ROLLOUTS} \
model.rs.rollout_micro_batch_size=8 \
model.rs.forward_micro_batch_size=8 \
model.rs.val_rollout_micro_batch_size=8 \
model.data.data_impl=jsonl \
remote_critic_rm.reward_model.ip=${host_critic} \
remote_critic_rm.reward_model.port=${CRITIC_PORT} \
model.rs.num_rollout_per_prompt=4 \
model.rs.num_select=1
EOF
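For readers new to the algorithm, a minimal sketch of the selection step that the model.rs.num_rollout_per_prompt and model.rs.num_select options above control (illustrative only; the actual trainer lives in nemo_aligner/algorithms/rs.py): each prompt gets several sampled responses, the remote reward model scores them, and only the highest-scoring responses per prompt are kept for training.

```python
import torch

def select_top_responses(reward_scores: torch.Tensor, num_select: int) -> torch.Tensor:
    """reward_scores: [num_prompts, num_rollout_per_prompt] rewards from the
    reward model. Returns the indices of the top `num_select` responses per
    prompt, i.e. the rollouts kept for training."""
    _, top_indices = torch.topk(reward_scores, k=num_select, dim=1)
    return top_indices

# Example: 2 prompts, 4 rollouts each, keep the single best response per prompt.
scores = torch.tensor([[0.1, 0.7, 0.3, 0.2],
                       [0.9, 0.4, 0.6, 0.5]])
print(select_top_responses(scores, num_select=1))  # tensor([[1], [0]])
```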
Before your PR is "Ready for review"
Pre checks:
Checklist when contributing a new algorithm
Does the trainer support max_steps=-1 and validation?
Additional Information