
Rejection sampling clean #218

Open

abukharin3 wants to merge 34 commits into main from rejection_sampling_clean
Conversation

@abukharin3 commented:

What does this PR do?

Adds the rejection sampling algorithm.
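For intuition, rejection sampling here amounts to best-of-n selection: generate several responses per prompt, score them with the reward model, and train only on the top-scoring ones. Below is a minimal sketch of that selection step (illustrative only; `select_topk_rollouts` is a hypothetical helper rather than this PR's actual code, and its two parameters mirror `model.rs.num_rollout_per_prompt` and `model.rs.num_select` from the usage command further down):

```python
import torch

def select_topk_rollouts(rewards: torch.Tensor, num_rollout_per_prompt: int, num_select: int) -> torch.Tensor:
    """Return flat indices of the `num_select` highest-reward rollouts per prompt.

    `rewards` holds one reward-model score per rollout, laid out prompt-major:
    [p0_sample0, ..., p0_sample{n-1}, p1_sample0, ...].
    """
    grouped = rewards.view(-1, num_rollout_per_prompt)    # (num_prompts, n)
    top_idx = grouped.topk(num_select, dim=-1).indices    # best samples per prompt
    offsets = torch.arange(grouped.size(0)).unsqueeze(-1) * num_rollout_per_prompt
    return (top_idx + offsets).flatten()                  # indices into the flat rollout batch

# With num_rollout_per_prompt=4 and num_select=1 (as in the usage command below),
# two prompts produce eight rollouts and only the best one per prompt is kept:
rewards = torch.tensor([0.1, 0.9, 0.3, 0.2, 0.5, 0.4, 0.8, 0.7])
print(select_topk_rollouts(rewards, num_rollout_per_prompt=4, num_select=1))  # tensor([1, 6])
```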

Changelog

  • Please update CHANGELOG.md under the next version with the high-level changes in this PR.

Usage

read -r -d '' cmd_ppo <<EOF
wandb login ${WANDB_API_KEY} \
&& cd ${NEMO_RLHF_DIR} \
&& export PYTHONPATH="${NEMO_RLHF_DIR}:${PYTHONPATH}" \
&& export HYDRA_FULL_ERROR=1 \
&& export CUDA_LAUNCH_BLOCKING=1 \
&& export PYTRITON_HOME=/pytriton_cache \
&& export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512 \
&& python -u examples/nlp/gpt/train_gpt_rs_actor.py \
--config-path=${CONF_DIR} \
--config-name=${CONFIG_NAME} \
"model.data.data_prefix={train: [${TRAIN_DATA_PATH}], validation: [${VALID_DATA_PATH}], test: [${VALID_DATA_PATH}]}" \
pretrained_checkpoint.restore_from_path="${ACTOR_NEMO_FILE}" \
exp_manager.checkpoint_callback_params.save_top_k=1 \
exp_manager.explicit_log_dir="${ACTOR_LOG_DIR}" \
exp_manager.create_wandb_logger=True \
exp_manager.wandb_logger_kwargs.name="${ACTOR_NAME}" \
exp_manager.wandb_logger_kwargs.project=${WANDB_PROJECT} \
++exp_manager.max_time_per_run="00:03:30:00" \
trainer.rs.max_epochs=1 \
trainer.rs.max_steps=313 \
trainer.rs.val_check_interval=4 \
trainer.num_nodes=8 \
trainer.devices=8 \
++model.tensor_model_parallel_size=4 \
model.global_batch_size=${ACTOR_GBS} \
model.micro_batch_size=1 \
model.optim.lr="\${multiply:${ACTOR_LR},1.001}" \
model.optim.sched.warmup_steps=0 \
model.optim.sched.constant_steps=312 \
model.optim.sched.min_lr=${ACTOR_LR} \
model.optim.weight_decay=0.01 \
model.rs.num_rollout_samples=${NUM_ROLLOUTS} \
model.rs.rollout_micro_batch_size=8 \
model.rs.forward_micro_batch_size=8 \
model.rs.val_rollout_micro_batch_size=8 \
model.data.data_impl=jsonl \
remote_critic_rm.reward_model.ip=${host_critic} \
remote_critic_rm.reward_model.port=${CRITIC_PORT} \
model.rs.num_rollout_per_prompt=4 \
model.rs.num_select=1
EOF
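Note that the heredoc only assembles the command into `cmd_ppo`; it is typically launched afterwards with something like `srun ... bash -c "${cmd_ppo}"` (the exact launcher depends on your cluster setup).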

Before your PR is "Ready for review"

Pre checks:

  • [Y] Make sure you read and followed the Contributor guidelines
  • [N] Did you write any new necessary tests?
  • [N] Did you add or update any necessary documentation? Make sure to also update the NeMo Framework User Guide, which contains the tutorials.

Checklist when contributing a new algorithm

  • [Y] Does the trainer resume and restore all model states?
  • [Y] Does the trainer support all parallelism techniques (PP, TP, DP)?
  • [Y] Does the trainer support max_steps=-1 and validation?
  • [Y] Does the trainer only call APIs defined in alignable_interface.py?
  • [Y] Does the trainer have proper logging?

Additional Information

  • Related to # (issue)

@github-actions bot added labels: documentation (Improvements or additions to documentation), Utils, Algorithms (Jun 24, 2024)
Resolved review thread: CHANGELOG.md (outdated)
@abukharin3 (Author) left a comment:

Olivier's Review

Resolved review threads: CHANGELOG.md (outdated); docs/user-guide/index.rst (outdated); docs/user-guide/rs.rst (3, outdated); nemo_aligner/models/nlp/gpt/reward_critic_clients.py (outdated); nemo_aligner/utils/ppo_utils.py (3, 2 outdated); nemo_aligner/utils/train_script_utils.py (outdated)
RESULTS_DIR="critic_results_dir"
export PYTHONPATH="${NEMO_RLHF_DIR}:${PYTHONPATH}" \
Collaborator: Should this say

Suggested change:
- export PYTHONPATH="${NEMO_RLHF_DIR}:${PYTHONPATH}" \
+ export PYTHONPATH="${GPFS}:${PYTHONPATH}" \

?

Author: Yes, it should.
Collaborator: Once all comments are addressed, please make sure to run the tutorial scripts to check that they work.

Resolved review threads: docs/user-guide/rs.rst (3, outdated)
Comment on lines 71 to 72
--config-path=${CONF_DIR} \
--config-name=${CONFIG_NAME} \
Collaborator: We should drop these so the example works OOTB.

Author: Done.

Resolved review threads: docs/user-guide/rs.rst (3, 2 outdated); examples/nlp/gpt/train_gpt_rs_actor.py (outdated); nemo_aligner/utils/train_script_utils.py (outdated)
'''

if num_rollout_samples % rollout_micro_batch_size != 0:
Author: Changed this function to use divide.

Author: I believe this function is still needed, as there are two cases to consider when calculating the mbs.

Collaborator: Hmm, sorry, maybe I misunderstand: why not just stack everything like you're doing in RSTrainer and then try our best to cut the batch down to the rollout mbs? If it's not divisible cleanly we can do a min.

Author: I removed compute_mbs and stack everything in RSTrainer as mentioned.
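A sketch of that stack-then-cut approach (a standalone illustration with hypothetical names, not the trainer's actual code): stack all rollouts into one batch, then slice off micro-batches of at most `rollout_micro_batch_size`, taking a min for the ragged tail instead of requiring clean divisibility.

```python
import torch

def cut_into_micro_batches(stacked: torch.Tensor, rollout_micro_batch_size: int):
    """Slice a stacked rollout batch into micro-batches of at most
    `rollout_micro_batch_size` rows; the tail micro-batch is simply smaller
    (a min) when the total is not cleanly divisible."""
    num_samples = stacked.size(0)
    for start in range(0, num_samples, rollout_micro_batch_size):
        end = min(start + rollout_micro_batch_size, num_samples)
        yield stacked[start:end]

# 34 stacked rollouts with mbs=8 -> micro-batches of sizes 8, 8, 8, 8, 2
stacked = torch.randn(34, 16)
print([mb.size(0) for mb in cut_into_micro_batches(stacked, 8)])
```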

Resolved review thread: nemo_aligner/utils/train_script_utils.py (outdated)
@abukharin3 force-pushed the rejection_sampling_clean branch 2 times, most recently from ef0a70a to 04b421d (September 18, 2024 00:54)
Alexander Bukharin added 3 commits September 17, 2024 20:55
@odelalleau (Collaborator) left a comment:

Did a pass through previous comments + some new ones

Resolved review threads: docs/user-guide/index.rst (outdated); docs/README.md (outdated); docs/user-guide/rs.rst (2, outdated); examples/nlp/gpt/conf/gpt_rs_actor.yaml (outdated)
def load_state_dict(self, state_dict):
    self.step = state_dict["step"]
    self.consumed_samples = state_dict["consumed_samples"]
    self.rs_optimization_step = state_dict["ppo_optimization_step"]  # legacy key, due to the way the checkpoint is saved
Collaborator: Any update on this?
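One option that would keep already-saved checkpoints loadable is a fallback lookup; a sketch under that assumption (mirroring the method above, not the PR's actual code):

```python
def load_state_dict(self, state_dict):
    self.step = state_dict["step"]
    self.consumed_samples = state_dict["consumed_samples"]
    if "rs_optimization_step" in state_dict:
        self.rs_optimization_step = state_dict["rs_optimization_step"]
    else:
        # Legacy checkpoints saved the counter under the PPO name.
        self.rs_optimization_step = state_dict["ppo_optimization_step"]
```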

# Need to pad response tokens before concatenating: response tokens contain prompts concatenated with responses.
current_batch["response_tokens"], rollout_batch["response_tokens"] = pad_batch(
    current_batch["response_tokens"], rollout_batch["response_tokens"], self.model.tokenizer.eos_id
)

current_batch["response_tokens"] = torch.concatenate(
    [current_batch["response_tokens"], rollout_batch["response_tokens"]], dim=0
)
Collaborator: Any thoughts on this? (Not a huge deal, but it seems a bit cleaner unless I'm missing something.)
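For context, `pad_batch` here has to right-pad the shorter batch's token sequences (with the tokenizer's `eos_id`) so both tensors share a sequence length before concatenation along the batch dimension. A minimal sketch with an assumed signature:

```python
import torch
import torch.nn.functional as F

def pad_batch(a: torch.Tensor, b: torch.Tensor, pad_id: int):
    """Right-pad two (batch, seq_len) token tensors with `pad_id` so they share
    the same sequence length and can be concatenated along the batch dim."""
    max_len = max(a.size(1), b.size(1))
    a = F.pad(a, (0, max_len - a.size(1)), value=pad_id)
    b = F.pad(b, (0, max_len - b.size(1)), value=pad_id)
    return a, b

a = torch.ones(2, 5, dtype=torch.long)
b = torch.ones(3, 7, dtype=torch.long)
a, b = pad_batch(a, b, pad_id=0)
print(torch.cat([a, b], dim=0).shape)  # torch.Size([5, 7])
```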

Resolved review threads: nemo_aligner/algorithms/rs.py (outdated); nemo_aligner/utils/train_script_utils.py (outdated)
RESULTS_DIR="critic_results_dir"
export PYTHONPATH="${NEMO_RLHF_DIR}:${PYTHONPATH}" \
Collaborator: Once all comments are addressed, please make sure to run the tutorial scripts to check that they work.
Resolved review thread: README.md
abukharin3 and others added 6 commits September 19, 2024 15:36
Co-authored-by: Olivier Delalleau <507137+odelalleau@users.noreply.github.com>
Resolved review thread: docs/user-guide/rs.rst (outdated)
logger=logger,
ckpt_callback=ckpt_callback,
run_timer=timer,
num_rollout_per_prompt=cfg.model.rs.num_rollout_per_prompt,
Collaborator: Might just be because it's still WIP on your side, but just to be sure it's not overlooked: you renamed the config options (num_rollout_per_prompt and top_n_rollouts) as per my suggestion, but haven't updated the code + doc accordingly yet.

Author (@abukharin3, Sep 21, 2024): Yes, it was WIP; the code and doc are updated now.

Labels: Algorithms, documentation, Utils
5 participants