
Add CPU-loaded multi-GPU quantization #289

Merged · 2 commits · Jan 4, 2024

Conversation

@xNul (Contributor) commented on Jan 3, 2024

GPU memory is very valuable. Usually, users don't have a lot of it, so they turn to techniques like quantization to run the models they want within their limited GPU memory. However, quantizing those models uses a lot of GPU memory as well.

Aside from full CPU-based quantization, the best technique I've found for quantizing models with limited GPU memory is to load the model into CPU memory and offload its layers into GPU memory one at a time. AutoAWQ allows this, but it only supports using one GPU for the entire quantization process, which limits the usable GPU memory to a single GPU's worth.

Here, I've added a few lines of code so that a model loaded in CPU memory, which would otherwise be offloaded to a single GPU for quantization, is offloaded across all GPUs instead. Layers are automatically distributed across all GPUs on the system via a simple algorithm, so every GPU is utilized in the quantization process.
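
In rough terms, the approach looks like the sketch below (the helper names assign_devices and quantize_block, and the step that returns each block to the CPU afterwards, are illustrative assumptions rather than the PR's actual code):

import torch

def assign_devices(layers):
    # Round-robin split: block i goes to GPU (i mod num_gpus).
    num_gpus = torch.cuda.device_count()
    return {i: f"cuda:{i % num_gpus}" for i in range(len(layers))}

def quantize_cpu_loaded_model(layers, quantize_block):
    # Offload one block at a time to its assigned GPU, quantize it there,
    # then move it back to the CPU so GPU memory is freed for later blocks.
    device_map = assign_devices(layers)
    for idx, block in enumerate(layers):
        device = device_map[idx]
        block.to(device)
        quantize_block(block, device)
        block.to("cpu")
        torch.cuda.empty_cache()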

Thanks to this change, I was able to quantize Mixtral 8x7b on my own system of 4x 24GB GPUs. Loading the model with device_map across all GPUs worked, but quantizing it on top of that resulted in a CUDA OOM error. Loading the model with device_map on the CPU also worked, but quantization would offload it to a single 24GB GPU, which only made it through half of the process before hitting CUDA OOM. With these changes, loading the model with device_map on the CPU offloads it across all four of my 24GB GPUs during quantization, and the quantization succeeds.

Model layers vary in size, so this simple algorithm is a naive one that could certainly be optimized to make more efficient use of the user's GPU memory; perhaps that can be done in the future (one possible refinement is sketched below). Additionally, I think the quantization of CPU-loaded models could be parallelized with this technique for significant performance improvements.
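
One possible refinement, sketched here as an assumption rather than anything in this PR, is a greedy assignment that balances the number of parameter bytes placed on each GPU:

import torch

def assign_devices_by_size(layers):
    # Greedy balance: send each block to the GPU with the fewest bytes
    # assigned so far, so large and small blocks even out across devices.
    num_gpus = torch.cuda.device_count()
    assigned_bytes = [0] * num_gpus
    device_map = {}
    for idx, block in enumerate(layers):
        size = sum(p.numel() * p.element_size() for p in block.parameters())
        target = min(range(num_gpus), key=lambda g: assigned_bytes[g])
        assigned_bytes[target] += size
        device_map[idx] = f"cuda:{target}"
    return device_map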

@casper-hansen (Owner)

That sounds great, makes sense that we can distribute modules like this!

Just to make sure we don't cause errors in other models, can you move the arguments to the common device only if they are not None? I.e., follow this code:

if layer_kwargs.get("attention_mask") is not None:
    layer_kwargs["attention_mask"] = layer_kwargs["attention_mask"].to("cuda")

@xNul (Contributor, Author) commented on Jan 3, 2024

@casper-hansen thanks! Good catch! I've added the changes. Anything else?

@casper-hansen (Owner)

LGTM! Thanks for adding these fixes! I have been able to quantize Mixtral with a single 48GB GPU, and I hope this solves OOM issues during scale computation for users who don't have a single GPU with that much VRAM available.
