diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index a930eaccef461c..e3000724ba5176 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1918,7 +1918,12 @@ def _inner_training_loop( "a `main_input_name` attribute to the model class you are using." ) else: - self.state.num_input_tokens_seen += self.accelerator.gather(inputs[main_input_name]).numel() + input_device = inputs[main_input_name].device + self.state.num_input_tokens_seen += torch.sum( + self.accelerator.gather( + torch.tensor(inputs[main_input_name].numel(), device=input_device, dtype=torch.int64) + ) + ).item() if rng_to_sync: self._load_rng_state(resume_from_checkpoint) rng_to_sync = False