
Fix multi-GPU loading and inference #190

Merged · 4 commits · Nov 14, 2023
Conversation

@casper-hansen (Owner) commented Nov 13, 2023

Resolves #162, Resolves #131, Resolves #143

  • Update the use of accelerate methods for multi-GPU (they broke at some point).
  • Fix memory issues related to multi-GPU.
    • CUDA error: an illegal memory access was encountered. This was caused by tensors not being on the right devices. The fix is to move tensors to the right device at the model level (see the sketch after this list); doing it only at the linear-module level was not a complete fix.
  • Note: hidden_states.to(attn_output.device) + attn_output may not be needed; more testing is required to confirm whether it is.
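A minimal sketch of the model-level device handling the second bullet refers to. The class and attribute names below are illustrative, not AutoAWQ's actual fused classes; it only assumes the decoder blocks have been dispatched to different GPUs (e.g. by accelerate).

```python
# Illustrative sketch, not AutoAWQ's actual code: move the hidden state to each
# block's device at the model level instead of inside every linear module.
import torch
import torch.nn as nn

class FusedDecoderSketch(nn.Module):
    """Stand-in for a fused decoder whose blocks may live on different GPUs."""

    def __init__(self, blocks, norm, lm_head):
        super().__init__()
        self.blocks = nn.ModuleList(blocks)  # each block may sit on a different GPU
        self.norm = norm
        self.lm_head = lm_head

    @torch.no_grad()
    def forward(self, h: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            # Device of this block's weights; the first parameter is representative.
            block_device = next(block.parameters()).device
            if h.device != block_device:
                # Moving the hidden state once per block keeps every op inside the
                # block on a single device and avoids illegal memory accesses.
                h = h.to(block_device)
            h = block(h)
        h = self.norm(h.to(next(self.norm.parameters()).device))
        return self.lm_head(h.to(next(self.lm_head.parameters()).device))
```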

@pseudotensor

cool!

@pseudotensor commented Jan 21, 2024

Hi @casper-hansen,

I'm running this model: TheBloke/openchat_3.5-16k-AWQ. While the 'balanced' device map spreads the model across all GPUs at load time, any use of the model with large-context input ends up using only the first GPU and going OOM.

i.e.

    generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
  File "/home/jon/miniconda3/envs/h2ogpt/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/jon/miniconda3/envs/h2ogpt/lib/python3.10/site-packages/transformers/generation/utils.py", line 1718, in generate
    return self.greedy_search(
  File "/home/jon/miniconda3/envs/h2ogpt/lib/python3.10/site-packages/transformers/generation/utils.py", line 2579, in greedy_search
    outputs = self(
  File "/home/jon/miniconda3/envs/h2ogpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jon/miniconda3/envs/h2ogpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jon/miniconda3/envs/h2ogpt/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/jon/miniconda3/envs/h2ogpt/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py", line 1053, in forward
    outputs = self.model(
  File "/home/jon/miniconda3/envs/h2ogpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jon/miniconda3/envs/h2ogpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jon/miniconda3/envs/h2ogpt/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/jon/miniconda3/envs/h2ogpt/lib/python3.10/site-packages/awq/modules/fused/model.py", line 101, in forward
    h, _, past_key_value = layer(
  File "/home/jon/miniconda3/envs/h2ogpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jon/miniconda3/envs/h2ogpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jon/miniconda3/envs/h2ogpt/lib/python3.10/site-packages/awq/modules/fused/block.py", line 65, in forward
    attn_output, _, past_key_value = self.attn.forward(
  File "/home/jon/miniconda3/envs/h2ogpt/lib/python3.10/site-packages/awq/modules/fused/attn.py", line 200, in forward
    scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 23.70 GiB. GPU 0 has a total capacty of 47.54 GiB of which 12.35 GiB is free. Including non-PyTorch memory, this process has 35.18 GiB memory in use. Of the allocated memory 33.16 GiB is allocated by PyTorch, and 795.74 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
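For scale, the failing matmul materializes the full attention-score tensor, so the allocation grows quadratically with sequence length. A rough estimate, assuming openchat_3.5's Mistral-style config (32 attention heads), fp16 scores, and batch size 1; the numbers are illustrative only:

```python
# Rough estimate only (assumptions: 32 heads, fp16, batch size 1).
# The failing line allocates scores of shape (batch, n_heads, seq_len, seq_len).
n_heads = 32
bytes_per_elem = 2  # fp16

def score_bytes(seq_len: int, batch: int = 1) -> int:
    return batch * n_heads * seq_len * seq_len * bytes_per_elem

print(f"{score_bytes(16_384) / 2**30:.1f} GiB")  # 16.0 GiB at 16k tokens
print(f"{score_bytes(20_000) / 2**30:.1f} GiB")  # 23.8 GiB, close to the 23.70 GiB above
```

The whole score tensor lands on whichever GPU runs that block's attention, so a long prompt can exhaust a single card even though the weights themselves are spread out.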

About 4 GB is on each GPU after loading, but usage blows up on the first GPU and leads to this. Is the forward not distributed?

i.e. post load:

[GPU memory screenshot: roughly 4 GB used per GPU]

Post failure:

[GPU memory screenshot: GPU 0 at capacity]
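For reference, a minimal sketch of the kind of setup being described above. It assumes AutoAWQ's AutoAWQForCausalLM.from_quantized accepts fuse_layers and device_map arguments (names may differ between versions), and the prompt length and generation settings are illustrative, not taken from the report.

```python
# Illustrative sketch only: multi-GPU AWQ load plus a long-context generate call.
# Assumption: from_quantized accepts fuse_layers and device_map in this version.
import torch
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "TheBloke/openchat_3.5-16k-AWQ"

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoAWQForCausalLM.from_quantized(
    model_path,
    fuse_layers=True,       # routes inference through awq/modules/fused/* as in the traceback
    device_map="balanced",  # spread weights across the available GPUs
)

# A long prompt (roughly 16k tokens here) drives the quadratic attention-score
# allocation shown in the traceback above.
long_prompt = "word " * 16_000
inputs = tokenizer(long_prompt, return_tensors="pt").to("cuda:0")

with torch.inference_mode():
    out = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```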
