Training on 8 A100 40GB GPUs? #19

Closed
RoyZhao926 opened this issue Jul 10, 2024 · 6 comments

Comments

@RoyZhao926

Is it feasible to train using 8 A100 40GB GPUs? I encountered GPU out-of-memory errors during the pre-training phase.

@RoyZhao926
Author

I found that it can only run on my GPUs when the batch size is 1. Can you tell me what your training resource configuration is? Is it because my computing resources are insufficient?

@machuofan
Collaborator

I normally use 8 A100 80GB GPUs for training. If your GPU memory is limited, you can turn on gradient accumulation and FSDP.
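Something like the following (a minimal sketch; the values are only illustrative, not the authors' recipe, and the field names are standard transformers.TrainingArguments options that also map to --gradient_accumulation_steps / --fsdp flags in vl_pretrain.sh):

```python
from transformers import TrainingArguments

# Illustrative values only: trade per-step memory for more accumulation steps
# and shard parameters/gradients/optimizer state across the 8 GPUs with FSDP.
training_args = TrainingArguments(
    output_dir="./checkpoints",
    per_device_train_batch_size=1,     # what fits on a 40GB card
    gradient_accumulation_steps=16,    # effective batch = 8 GPUs x 1 x 16 = 128
    bf16=True,
    fsdp="full_shard auto_wrap",       # shard params, grads, optimizer state
)
```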

@RoyZhao926
Author

Thank you for your reply! Can this model be trained using fp16 or bf16? I found --bf16 True in your vl_pretrain.sh, but the following error occurs when I use fp16 or bf16 for training:
```
Traceback (most recent call last):
  File "Groma/groma/train/train_mem.py", line 13, in <module>
    train()
  File "Groma/groma/train/train.py", line 156, in train
    trainer.train()
  File "/data/miniconda3/envs/env-3.10.6/lib/python3.10/site-packages/transformers/trainer.py", line 1539, in train
    return inner_training_loop(
  File "/data/miniconda3/envs/env-3.10.6/lib/python3.10/site-packages/transformers/trainer.py", line 1869, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/data/miniconda3/envs/env-3.10.6/lib/python3.10/site-packages/transformers/trainer.py", line 2768, in training_step
    loss = self.compute_loss(model, inputs)
  File "/data/miniconda3/envs/env-3.10.6/lib/python3.10/site-packages/transformers/trainer.py", line 2791, in compute_loss
    outputs = model(**inputs)
  File "/data/miniconda3/envs/env-3.10.6/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/miniconda3/envs/env-3.10.6/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/miniconda3/envs/env-3.10.6/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/data/miniconda3/envs/env-3.10.6/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1735, in forward
    loss = self.module(*inputs, **kwargs)
  File "/data/miniconda3/envs/env-3.10.6/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/miniconda3/envs/env-3.10.6/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "Groma/groma/model/groma.py", line 245, in forward
    ddetr_outs = self.perceiver.ddetr_transformer(srcs, masks, return_dict=True)
  File "/data/miniconda3/envs/env-3.10.6/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/miniconda3/envs/env-3.10.6/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "Groma/groma/model/ddetr_transformer.py", line 679, in forward
    outputs = self.extract_feature(
  File "Groma/groma/model/ddetr_transformer.py", line 530, in extract_feature
    encoder_outputs = self.encoder(
  File "/data/miniconda3/envs/env-3.10.6/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/miniconda3/envs/env-3.10.6/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/miniconda3/envs/env-3.10.6/lib/python3.10/site-packages/transformers/models/deformable_detr/modeling_deformable_detr.py", line 1238, in forward
    layer_outputs = encoder_layer(
  File "/data/miniconda3/envs/env-3.10.6/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/miniconda3/envs/env-3.10.6/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/miniconda3/envs/env-3.10.6/lib/python3.10/site-packages/transformers/models/deformable_detr/modeling_deformable_detr.py", line 878, in forward
    hidden_states, attn_weights = self.self_attn(
  File "/data/miniconda3/envs/env-3.10.6/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/miniconda3/envs/env-3.10.6/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/miniconda3/envs/env-3.10.6/lib/python3.10/site-packages/transformers/models/deformable_detr/modeling_deformable_detr.py", line 709, in forward
    output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
  File "/data/miniconda3/envs/env-3.10.6/lib/python3.10/site-packages/transformers/models/deformable_detr/modeling_deformable_detr.py", line 564, in multi_scale_deformable_attention
    sampling_value_l_ = nn.functional.grid_sample(
  File "/data/miniconda3/envs/env-3.10.6/lib/python3.10/site-packages/torch/nn/functional.py", line 4304, in grid_sample
    return torch.grid_sampler(input, grid, mode_enum, padding_mode_enum, align_corners)
RuntimeError: "grid_sampler_2d_cuda" not implemented for 'BFloat16'
```

This seems to be because the underlying implementation of ddetr only supports float32; the same issue has been reported in its repository. Have you ever tried training with fp16 or bf16?
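In case it helps, the failure is easy to reproduce in isolation; a minimal sketch (assuming a CUDA device and a torch build where the grid_sampler_2d_cuda kernel has no BFloat16 variant), together with the cast-to-fp32 workaround discussed further down:

```python
import torch
import torch.nn.functional as F

# Tensors shaped like the inputs to the sampling step inside multi_scale_deformable_attention
value = torch.randn(1, 8, 16, 16, device="cuda", dtype=torch.bfloat16)      # (N, C, H, W)
grid = torch.rand(1, 4, 4, 2, device="cuda", dtype=torch.bfloat16) * 2 - 1  # (N, H_out, W_out, 2) in [-1, 1]

# F.grid_sample(value, grid, align_corners=False)
# -> RuntimeError: "grid_sampler_2d_cuda" not implemented for 'BFloat16'

# Workaround: run the sampling in fp32 and cast the result back to bf16.
out = F.grid_sample(value.float(), grid.float(), align_corners=False).to(value.dtype)
```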

@machuofan
Collaborator

Yes, the model is trained with bf16=True by default. From my experience, the above error is probably caused by incompatible torch and transformers versions. Did you install the same versions as specified in pyproject.toml?

@RoyZhao926
Author

Thank you for your response. I did follow pyproject.toml and used the script scripts/vl_pretrain.sh. I found that even when I set --bf16 True and --tf32 True, the model still trains in float32 (I observed this by printing the network parameters and intermediate results; I wonder if you might have the same problem), which may be caused by my findings below.

As I found earlier, running ddetr in a non-float32 dtype raises an error:
RuntimeError: "ms_deform_attn_forward_cuda" not implemented for 'Half'
This is probably because it does not support mixed precision on its own. I also found that in roi_align.py the inputs are forcibly converted to float32, which suggests that the original implementation only considered float32.

In the end, I had to force ddetr to run in fp32 via @force_fp32 and convert some of the input types in roi_align.py so that the model could be trained in bf16.
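For context, a hand-rolled equivalent of that @force_fp32 pattern (a minimal sketch, not the decorator actually used in the repo): cast floating-point tensor arguments to fp32 before the wrapped call and cast the result back to the caller's dtype, so only the ddetr/roi_align parts leave bf16.

```python
import functools
import torch

def force_fp32_call(fn):
    """Run `fn` in fp32 even when the surrounding model uses bf16/fp16."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        orig_dtype = None

        def to_fp32(x):
            nonlocal orig_dtype
            if torch.is_tensor(x) and x.is_floating_point():
                orig_dtype = orig_dtype or x.dtype
                return x.float()
            return x

        args = [to_fp32(a) for a in args]
        kwargs = {k: to_fp32(v) for k, v in kwargs.items()}
        out = fn(*args, **kwargs)
        # Cast a tensor result back so downstream bf16 layers see their own dtype.
        if orig_dtype is not None and torch.is_tensor(out) and out.is_floating_point():
            out = out.to(orig_dtype)
        return out
    return wrapper
```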

@machuofan
Collaborator

Thanks for your feedback. Do you mean all the model parameters are trained in fp32, even though you set bf16=True? Or just the ddetr parameters? bf16 works in my local environment, as turning on bf16 significantly boosts training speed compared with bf16 off.

It is reasonable to see features converted to fp32 before they are fed into the ddetr transformer, since the underlying CUDA implementation of MultiScaleDeformableAttention supports fp16 but not bf16. Normally this conversion is done automatically, but I think using @force_fp32 to force the conversion is still fine.
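A quick way to check whether it is the whole model or only the ddetr parts that stay in fp32 is to count parameters per dtype; a minimal sketch (assuming `model` is the instantiated Groma model after the training arguments have been applied):

```python
from collections import Counter

# Tally how many parameters live in each dtype; if bf16 is really active,
# the bulk of the count should be torch.bfloat16 rather than torch.float32.
dtype_counts = Counter()
for param in model.parameters():
    dtype_counts[param.dtype] += param.numel()
print(dtype_counts)
```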
