
Commit

Fix num_micro_batches when using forward_mbs (#267)
* fix num_mbs with forward_mbs

Signed-off-by: Shengyang Sun <shengyangs@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* switch to the divide function

Signed-off-by: Shengyang Sun <shengyangs@nvidia.com>

---------

Signed-off-by: Shengyang Sun <shengyangs@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
shengyangs and pre-commit-ci[bot] committed Aug 20, 2024
1 parent 96222d0 commit e809db4
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py
@@ -19,6 +19,7 @@
 from megatron.core import parallel_state
 from megatron.core.num_microbatches_calculator import get_num_microbatches
 from megatron.core.pipeline_parallel.schedules import get_forward_backward_func
+from megatron.core.utils import divide
 from omegaconf.dictconfig import DictConfig
 from pytorch_lightning.trainer.trainer import Trainer

@@ -405,8 +406,10 @@ def finish_validation_step(self):
     @torch.no_grad()
     def get_logprob_batch(self, batch):
         seq_length = batch["chosen"].shape[1]
+        batch_size = batch["chosen"].shape[0]

-        data_iter = get_iterator_k_split(batch, get_num_microbatches())
+        num_microbatches = divide(batch_size * 2, self.cfg.dpo.log_prob_forward_micro_batch_size)
+        data_iter = get_iterator_k_split(batch, num_microbatches)
         set_sync_funcs(self, forward_only=True)

         fwd_bwd_function = get_forward_backward_func()
@@ -415,7 +418,7 @@ def get_logprob_batch(self, batch):
             forward_step_func=self.get_forward_output_and_loss_func(logprobs_only=True),
             data_iterator=data_iter,
             model=self.model,
-            num_microbatches=get_num_microbatches(),
+            num_microbatches=num_microbatches,
             forward_only=True,
             seq_length=seq_length,
             micro_batch_size=self.cfg.dpo.log_prob_forward_micro_batch_size,
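
The change drops the training-time get_num_microbatches() in favour of a count derived from the batch itself: each DPO example carries a chosen and a rejected sequence, so 2 * batch_size samples are split into forward micro-batches of size log_prob_forward_micro_batch_size. Below is a minimal sketch of that arithmetic; the local divide() is a hypothetical stand-in assumed to behave like megatron.core.utils.divide (check divisibility, return the integer quotient), and the numbers are illustrative only.

# Hypothetical stand-in for megatron.core.utils.divide (assumption: it asserts
# divisibility before returning the quotient).
def divide(numerator: int, denominator: int) -> int:
    assert numerator % denominator == 0, f"{numerator} is not divisible by {denominator}"
    return numerator // denominator

# Example values, not taken from the repository.
batch_size = 8     # prompts in the log-prob batch
forward_mbs = 4    # cfg.dpo.log_prob_forward_micro_batch_size

# A DPO batch holds a chosen and a rejected response per prompt,
# hence 2 * batch_size sequences go through the forward pass.
num_microbatches = divide(batch_size * 2, forward_mbs)
print(num_microbatches)  # 4, rather than the training schedule's get_num_microbatches()

The switch to the divide function (noted in the commit message) presumably makes an uneven split fail loudly rather than silently rounding the micro-batch count.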
