Skip to content

Commit

Permalink
Fix gather when collecting 'num_input_tokens_seen' (#31974)
Browse files Browse the repository at this point in the history
* Move token count to device before gathering

* Run 'make style; make quality'
  • Loading branch information
CodeCreator authored Jul 16, 2024
1 parent c22efa6 commit e391706
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2245,12 +2245,17 @@ def _inner_training_loop(
"a `main_input_name` attribute to the model class you are using."
)
else:
input_device = inputs[main_input_name].device
self.state.num_input_tokens_seen += torch.sum(
self.accelerator.gather(
torch.tensor(inputs[main_input_name].numel(), device=input_device, dtype=torch.int64)
self.state.num_input_tokens_seen += (
torch.sum(
self.accelerator.gather(
torch.tensor(
inputs[main_input_name].numel(), device=self.args.device, dtype=torch.int64
)
)
)
).item()
.cpu()
.item()
)
if rng_to_sync:
self._load_rng_state(resume_from_checkpoint)
rng_to_sync = False
Expand Down

0 comments on commit e391706

Please sign in to comment.