
Inappropriate reduce operation of "num_input_tokens_seen" is prone to get training stuck. #28791

Closed
4 tasks done
YouliangHUANG opened this issue Jan 31, 2024 · 11 comments · Fixed by #29099

Comments

@YouliangHUANG
Contributor

YouliangHUANG commented Jan 31, 2024

System Info

Trivial

Who can help?

@pacman100

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

See src/transformers/trainer.py line 1870
self.state.num_input_tokens_seen += self.accelerator.gather(inputs[main_input_name]).numel()

The length of "inputs[main_input_name]" is not guaranteed to be the same across processes when using DDP, which can make the training process hang. Moreover, in a distributed setup, gathering the WHOLE input tensors from every worker is expensive. It is better to call .numel() first and then .gather().
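The two points above can be illustrated with a pure-Python sketch (no torch required) under simplifying assumptions: each rank's inputs[main_input_name] is modeled as a flat list of token IDs, and the hypothetical simulated_all_gather helper stands in for a collective that needs equal payload sizes on every rank. A real all_gather with mismatched shapes may hang instead of raising a clean error as it does here.

```python
from typing import List

# Hypothetical flattened inputs[main_input_name] on two DDP ranks.
# With dynamic padding the element counts can differ between ranks.
rank_batches: List[List[int]] = [
    [101, 2023, 2003, 102],              # rank 0: 4 tokens
    [101, 2178, 6251, 2003, 2936, 102],  # rank 1: 6 tokens
]


def simulated_all_gather(payloads: List[List[int]]) -> List[List[int]]:
    """Stand-in for a collective that needs equal-sized payloads on every rank."""
    sizes = sorted({len(p) for p in payloads})
    if len(sizes) != 1:
        # A real all_gather may hang here instead of raising.
        raise RuntimeError(f"payload sizes differ across ranks: {sizes}")
    return [list(p) for p in payloads]


# Gather-then-count: ships every token and breaks on ragged shapes.
try:
    simulated_all_gather(rank_batches)
    whole_gather_ok = True
except RuntimeError:
    whole_gather_ok = False

# Count-then-gather: each rank ships one integer, so sizes always match.
local_counts = [[len(batch)] for batch in rank_batches]
gathered_counts = simulated_all_gather(local_counts)
total_tokens = sum(count[0] for count in gathered_counts)

print(whole_gather_ok, gathered_counts, total_tokens)  # False [[4], [6]] 10
```

Besides avoiding the shape mismatch, counting first shrinks the payload from every token on every rank to a single integer per worker.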

Ref: Stuck when using model.generate() and acclerator.gather() in the distributed setting

Expected behavior

Fix:
# Count tokens locally, then gather one int64 count per process and sum them.
input_device = inputs[main_input_name].device
self.state.num_input_tokens_seen += torch.sum(
    self.accelerator.gather(
        torch.tensor(inputs[main_input_name].numel(), device=input_device, dtype=torch.int64)
    )
).item()

@thincal

thincal commented Feb 1, 2024

  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1851, in _inner_training_loop
    self.state.num_input_tokens_seen += torch.sum(self.accelerator.gather(
  File "/opt/conda/lib/python3.10/site-packages/accelerate/accelerator.py", line 2159, in gather
    return gather(tensor)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/utils/operations.py", line 344, in wrapper
    return function(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/utils/operations.py", line 405, in gather
    return _gpu_gather(tensor)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/utils/operations.py", line 324, in _gpu_gather
    return recursively_apply(_gpu_gather_one, tensor, error_on_other_type=True)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/utils/operations.py", line 129, in recursively_apply
    return func(data, *args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/utils/operations.py", line 321, in _gpu_gather_one
    torch.distributed.all_gather(output_tensors, tensor)
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 47, in wrapper
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 2806, in all_gather
    work = default_pg.allgather([tensor_list], [tensor])
RuntimeError: No backend type associated with device type cpu

I patched transformers with the above hotfix, but the error still occurs. Could you take a look? Thanks.

@YouliangHUANG
Contributor Author


I also ran into this problem at first, which is why I added input_device = inputs[main_input_name].device, so that the new tensor is created on the same device as the inputs.
Since the original code works, creating the new tensor on that same device should work as well. Can you double-check which device the tensor is assigned to?

@pacman100
Contributor

Thank you @YouliangHUANG for the issue as well as the suggestion to fix it. It makes sense; it would be great if you could open a PR with the suggested fix.

YouliangHUANG added a commit to YouliangHUANG/transformers-fix-num_input_tokens_seen that referenced this issue Feb 19, 2024
@thincal

thincal commented Feb 19, 2024


@YouliangHUANG

inputs[main_input_name].numel(): 1336
inputs[main_input_name].device: cpu

After applying the fix, the same error still occurs (transformers==4.37.2):

  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1855, in _inner_training_loop
    self.accelerator.gather(
  File "/opt/conda/lib/python3.10/site-packages/accelerate/accelerator.py", line 2161, in gather
    return gather(tensor)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/utils/operations.py", line 376, in wrapper
    return function(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/utils/operations.py", line 437, in gather
    return _gpu_gather(tensor)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/utils/operations.py", line 356, in _gpu_gather
    return recursively_apply(_gpu_gather_one, tensor, error_on_other_type=True)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/utils/operations.py", line 135, in recursively_apply
    return func(data, *args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/utils/operations.py", line 353, in _gpu_gather_one
    torch.distributed.all_gather(output_tensors, tensor)
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 47, in wrapper
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 2806, in all_gather
    work = default_pg.allgather([tensor_list], [tensor])
RuntimeError: No backend type associated with device type cpu

@thincal

thincal commented Feb 19, 2024

It works now after forcing the device to 'cuda', so it seems the original error is caused by the all_gather op not being supported on the CPU device?

@YouliangHUANG
Contributor Author

YouliangHUANG commented Feb 19, 2024


@thincal Please check your backend type, and refer to https://pytorch.org/docs/stable/distributed.html for more details.

@thincal

thincal commented Feb 19, 2024


Yes, the NCCL backend is in use, and it doesn't support CPU tensors.

@thincal

thincal commented Feb 19, 2024

The length of "inputs[main_input_name]" is not guaranteed to be the same when using ddp, which may make the training process hang.

So which change solves this problem?

@YouliangHUANG
Contributor Author


torch.tensor(inputs[main_input_name].numel(), device=input_device, dtype=torch.int64)
@thincal This code creates a single-element tensor that records how many input tokens are on the local worker. Since every worker contributes a tensor of the same shape, the tensors can be gathered with self.accelerator.gather and then summed into the total count.
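To make the counting logic concrete, here is a minimal pure-Python sketch of how the counter accumulates over a couple of steps. TrainerStateSketch, gather_counts, and the batch shapes are hypothetical; gather_counts only mimics what gathering one single-element tensor per rank would return, and no real Trainer or Accelerate code is called.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class TrainerStateSketch:
    """Hypothetical stand-in for TrainerState; only the token counter is modeled."""
    num_input_tokens_seen: int = 0


def gather_counts(local_counts: List[int]) -> List[int]:
    # Mimics gathering one single-element tensor per rank:
    # the result always has exactly world_size entries.
    return list(local_counts)


def update_tokens_seen(state: TrainerStateSketch, shapes: List[Tuple[int, int]]) -> None:
    # Each rank computes .numel() of its own (batch_size, seq_len) batch locally...
    local_counts = [batch_size * seq_len for batch_size, seq_len in shapes]
    # ...then only the counts are gathered and summed into the global total.
    state.num_input_tokens_seen += sum(gather_counts(local_counts))


state = TrainerStateSketch()
# Two training steps on two ranks; sequence lengths differ per rank and per step.
for step_shapes in [[(2, 128), (2, 96)], [(2, 64), (2, 200)]]:
    update_tokens_seen(state, step_shapes)

print(state.num_input_tokens_seen)  # 976
```

The ragged sequence lengths never reach the collective; only the per-rank totals (448 tokens in the first step, 528 in the second) do.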

@thincal

thincal commented Feb 19, 2024


OK, that's great. But shouldn't the device be chosen according to the DDP backend?
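One possible answer, sketched as a hypothetical helper: pick the count tensor's device from the process-group backend. It assumes the backend support table in the PyTorch distributed docs (NCCL collectives take CUDA tensors only; Gloo's all_gather takes CPU tensors) and that process_device is the CUDA device owned by the current process, for example what accelerator.device reports. This is only an illustration, not necessarily what the eventual fix does.

```python
def count_tensor_device(backend: str, input_device: str, process_device: str) -> str:
    """Hypothetical helper: pick the device for the per-rank token-count tensor.

    backend: process-group backend name, e.g. "nccl" or "gloo".
    input_device: device of inputs[main_input_name] (may be "cpu").
    process_device: device owned by this process, e.g. "cuda:0".
    """
    backend = backend.lower()
    if backend == "nccl":
        # NCCL only handles CUDA tensors; a CPU count fails with
        # "No backend type associated with device type cpu".
        return process_device
    if backend == "gloo":
        # Gloo supports all_gather on CPU tensors.
        return "cpu"
    # Other backends: keep the tensor next to the inputs.
    return input_device


# The situation reported in this thread: inputs on CPU, NCCL backend.
print(count_tensor_device("nccl", "cpu", "cuda:0"))  # cuda:0
print(count_tensor_device("gloo", "cuda:0", "cuda:0"))  # cpu
```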


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

ArthurZucker pushed a commit that referenced this issue Mar 25, 2024
fix the behavior of collecting 'num_input_tokens_seen'

See #28791 for more details.
hovnatan pushed a commit to hovnatan/transformers that referenced this issue Mar 27, 2024
…9099)

fix the behavior of collecting 'num_input_tokens_seen'

See huggingface#28791 for more details.
itazap pushed a commit that referenced this issue May 14, 2024
fix the behavior of collecting 'num_input_tokens_seen'

See #28791 for more details.
Labels
None yet
Projects
None yet
3 participants