
Add support for llama.cpp #27712

Open
oobabooga opened this issue Nov 26, 2023 · 16 comments
Labels
Core: Modeling (Internals of the library; Models.), WIP (Label your PR/Issue with WIP for some long outstanding Issues/PRs that are work in progress)

Comments

@oobabooga
Contributor

oobabooga commented Nov 26, 2023

Feature request

I would like to request llama.cpp as a new model backend in the transformers library.

Motivation

llama.cpp offers:

  1. Excellent performance in scenarios where memory bandwidth is an issue, namely CPU inference and GPU + CPU inference.
  2. Support for a wide range of GPU vendors and models.
  3. Adequate quantization accuracy -- I have compared the perplexities of 4-bit GGUF models to GPTQ, AWQ, EXL2, and bitsandbytes and found them to be competitive (link).

By making the transformers library compatible with GGUF models, the llama.cpp performance on consumer hardware could hopefully be combined with the features available in transformers and its surrounding ecosystem. In particular, it would be interesting to see the transformers sampling and decoding features working seamlessly with llama.cpp.

Your contribution

I have implemented a "llamacpp_HF" wrapper in the file below:

https://github.com/oobabooga/text-generation-webui/blob/main/modules/llamacpp_hf.py

It makes it possible to use the transformers model.generate with llama.cpp models, and it exemplifies how to make forward calls in llama.cpp and get the logits. It works for perplexity evaluation when logits_all=True is passed while loading the model. I additionally implemented some prefix-matching logic and a hacky way to recognize forward calls for negative prompts to make CFG functional.
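
For readers unfamiliar with that wrapper, a minimal sketch of the kind of forward call it performs through llama-cpp-python follows below. The model path and prompt are placeholders, and the scores / n_tokens attribute names follow the llama-cpp-python high-level API and may differ between versions:

    import numpy as np
    from llama_cpp import Llama

    # Load a GGUF model; logits_all=True keeps the logits for every position,
    # which is what perplexity evaluation needs.
    llm = Llama(
        model_path="path/to/model.Q4_K_M.gguf",  # placeholder
        n_ctx=2048,
        n_gpu_layers=0,
        logits_all=True,
    )

    # Tokenize the prompt and run a forward pass through llama.cpp.
    tokens = llm.tokenize(b"The quick brown fox")
    llm.eval(tokens)

    # llm.scores holds one row of logits per evaluated position.
    logits = np.asarray(llm.scores)[: llm.n_tokens]  # shape: (n_tokens, n_vocab)
    next_token_logits = logits[-1]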

For the llama.cpp transformers integration, I recommend the following:

  • Relying on the llama-cpp-python library: https://github.com/abetlen/llama-cpp-python/
  • Requiring the user to manually install llama-cpp-python with the appropriate command for their hardware rather than adding it as a direct requirement to transformers. I believe that's how it already works for GPTQ models, where AutoGPTQ has to be installed manually.
  • In the from_pretrained call, having a LlamaCppConfig object that takes arbitrary kwargs as input, which later get passed to the llama_cpp.Llama model-loading call. That would be similar to the BitsAndBytesConfig object that is passed to from_pretrained when load_in_4bit=True is used. Some important parameters are n_gpu_layers and n_ctx; it would be interesting to make this future-proof and allow arbitrary kwargs to be passed to LlamaCppConfig (a rough sketch follows this list).
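
To make that last point concrete, here is a rough sketch of what such a config object could look like. The LlamaCppConfig class and the llama_cpp_config argument below are hypothetical names used only for illustration; the real design would follow whatever shape #26610 settles on for quantization configs:

    from dataclasses import dataclass, field
    from typing import Any, Dict


    @dataclass
    class LlamaCppConfig:
        """Hypothetical container for kwargs forwarded to llama_cpp.Llama(...)."""

        n_gpu_layers: int = 0
        n_ctx: int = 2048
        # Catch-all so new llama.cpp loading options keep working without
        # transformers having to track every parameter explicitly.
        extra_kwargs: Dict[str, Any] = field(default_factory=dict)

        def to_llama_kwargs(self) -> Dict[str, Any]:
            return {"n_gpu_layers": self.n_gpu_layers, "n_ctx": self.n_ctx, **self.extra_kwargs}


    # Intended usage (hypothetical, mirroring how BitsAndBytesConfig is passed):
    # model = AutoModelForCausalLM.from_pretrained(
    #     "path/to/gguf-model",
    #     llama_cpp_config=LlamaCppConfig(n_gpu_layers=35, n_ctx=4096),
    # )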

I'll tag @younesbelkada who worked with RWKV and AWQ integration in transformers and may find this interesting.

@younesbelkada
Contributor

younesbelkada commented Jan 22, 2024

Hi @oobabooga !
Apologies for my late reply
In general we are very interested in adding new quantization schemes in HF transformers. Currently, we're waiting to merge #26610 in order to make the support for new quantization methods easier for anyone in the future.
We had some internal discussion about adding Llama.cpp inference support in transformers, and currently we feel that the Llama.cpp library is moving too fast to be added to HF transformers, which would make it quite challenging to maintain overall. This is debatable, so feel free to let us know what you think about it, and we can consider adding Llama.cpp after #26610 gets merged.

@poedator
Contributor

@oobabooga, it should be possible to externally create a subclass of HFTransformers with llama.cpp support, independent from the GptqHfQuantizer class. It could be hosted outside transformers.

@younesbelkada
Contributor

Just discussed offline with @ArthurZucker - indeed, you can import the auto mapping that lives here: https://github.com/huggingface/transformers/blob/main/src/transformers/quantizers/auto.py and add new quantizers that would initially live inside text-generation-webui. If we see that everything is quite stable and not subject to a lot of breaking changes, we can port the quantizers back into transformers core. How does that sound @oobabooga? I can also work on a PoC PR in your repo.
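
As a very rough sketch of that registration idea: the AUTO_QUANTIZER_MAPPING and AUTO_QUANTIZATION_CONFIG_MAPPING names are taken from the linked quantizers/auto.py, while the LlamaCppHfQuantizer and LlamaCppQuantConfig classes below are hypothetical placeholders that would be defined in the external project:

    from transformers.quantizers.auto import (
        AUTO_QUANTIZATION_CONFIG_MAPPING,
        AUTO_QUANTIZER_MAPPING,
    )


    class LlamaCppHfQuantizer:  # hypothetical; would subclass transformers' HfQuantizer
        ...


    class LlamaCppQuantConfig:  # hypothetical; would mirror other quantization configs
        ...


    # Make the auto machinery aware of a "llama_cpp" quantization method.
    AUTO_QUANTIZER_MAPPING["llama_cpp"] = LlamaCppHfQuantizer
    AUTO_QUANTIZATION_CONFIG_MAPPING["llama_cpp"] = LlamaCppQuantConfig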

@oobabooga
Contributor Author

oobabooga commented Jan 30, 2024

@younesbelkada a PR kickstarting that addition in text-generation-webui would be extremely appreciated. I am not familiar enough with the transformers internals to do it myself -- in particular, porting the llama.cpp cache to transformers has been a blocker in my attempts.

llama-cpp-python has a comprehensive API. The necessary functions should all be in this file:

https://github.com/abetlen/llama-cpp-python/blob/da003d87681f02475eedb6937443e5f07db889b0/llama_cpp/llama_cpp.py#L1291

After your PoC, I should be able to maintain the code and accept PRs so it becomes more stable over time.

@ArthurZucker added the Core: Modeling label Mar 25, 2024
@MotivaoCrypto

Where are we on support for llama.cpp in the Transformers package?

@endomorphosis

I would like to see this as well.

@LysandreJik
Member

Hey! Thanks @oobabooga for the feature request. Before diving into technicalities, I'd like to understand what feature you really want, as there are many areas where llama.cpp and transformers can be linked, and many toolkits out there already doing so.

When you say you want to have llama.cpp as a backend for transformers, do you mean that you would like:

  1. To have transformers use llama.cpp/ggml as a backend like we have for torch/tf/jax (which means an implementation of the models in that format)
  2. To have transformers link to llama.cpp underneath, for example with Python bindings, offering the transformers Python API while leveraging llama.cpp under the hood
  3. To have transformers load gguf files and use them with the current backends (so torch, for example).

For 1., I think there is a significant amount of work needed to enable that, and I'm not sure it would benefit either llama.cpp or transformers, so I don't think that's necessarily what you want to do; if it is, please let me know.

For 2., what would be the difference from existing toolkits such as llama-cpp-python or ctransformers?

For 3., that's actually something we've discussed with @ggerganov and which we're exploring right now. This would mean converting gguf files back to fp32 so they can be used with transformers and the rest of the Python ecosystem (for example for training, fine-tuning, LoRA, etc.), before converting them back to the gguf file format afterwards.

Thanks for taking the time!

@oobabooga
Contributor Author

I want to be able to use every sampling function available in the transformers library while having llama.cpp as the backend for the forward passes and the cache handling. That means I do not want to simply convert GGUF models to PyTorch on the fly.

In practical terms, I want to be able to use model.generate() with parameters like the following, all while being able to split the workload between the GPU and the CPU for the forward passes with the n_gpu_layers parameter (a sketch of such calls follows the list):

  • Contrastive search (through the penalty_alpha parameter)
  • Prompt Lookup Decoding (through the prompt_lookup_num_tokens parameter).
  • CFG (through the guidance_scale parameter).
  • Assisted decoding (using the exact code provided by @gante in his blog post).
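
For reference, here is a sketch of the generate() calls referred to above, using small placeholder models (gpt2 and distilgpt2) to stand in for whatever backend ultimately serves the forward passes:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    model = AutoModelForCausalLM.from_pretrained("gpt2")              # placeholder model
    draft_model = AutoModelForCausalLM.from_pretrained("distilgpt2")  # placeholder draft model
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    inputs = tokenizer("The llama.cpp project", return_tensors="pt")

    # Contrastive search
    out = model.generate(**inputs, penalty_alpha=0.6, top_k=4, max_new_tokens=64)

    # Prompt lookup decoding
    out = model.generate(**inputs, prompt_lookup_num_tokens=10, max_new_tokens=64)

    # Classifier-free guidance (CFG)
    out = model.generate(**inputs, guidance_scale=1.5, max_new_tokens=64)

    # Assisted decoding with a smaller draft model
    out = model.generate(**inputs, assistant_model=draft_model, max_new_tokens=64)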

So, option 2 seems like what I am looking for.

Option 3 also has appeal in that Transformers currently can't load models with fractional bits per weight (like the EXL2 format or the llama.cpp k-quants use), so loading a q4_K_M model in PyTorch format would have its own merit. But that's tangential to my motivation for creating this issue.

@gante
Member

gante commented Apr 19, 2024

@oobabooga If I'm getting it right: you'd like to replace the model object with something that runs llama.cpp under the hood, but would have the exact same .generate API and user experience. By making the core object compatible with .generate, you would get the benefit of a single API for both cases and all the generate features of transformers available on the llama.cpp side. Is this correct? 🤗

I think it is doable, but it is not entirely trivial -- .generate depends on several model and model config attributes, so there has to be a wrapper class on the llama.cpp model (inheriting from GenerationMixin) to define those attributes as well as to overwrite a few internal methods as needed. Coincidentally, our next .generate project is to refactor it so that it stops being a monolith, which pairs well with the (likely) need to overwrite a few details.

I'm out of bandwidth until the aforementioned project is completed, but it seems like a nice project for a new external library. WDYT @LysandreJik ?

@oobabooga
Contributor Author

Yes, you understood my goal correctly. Note that most sampling parameters in Transformers already work in my project with GGUF files through the LlamacppHF class, which inherits from PreTrainedModel and returns logits in the format expected by the library in the __call__ method. But it doesn't work with sampling parameters that involve passing and returning cache values to __call__.
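
To illustrate the shape of that wrapper, here is a heavily simplified sketch. It is not the actual LlamacppHF code: a real version would subclass PreTrainedModel with a proper config, handle batching, and reuse the llama.cpp cache instead of re-evaluating the prompt, and the llama-cpp-python attribute names (scores, n_tokens) may differ between versions:

    import torch
    from llama_cpp import Llama
    from transformers.modeling_outputs import CausalLMOutputWithPast


    class LlamaCppWrapperSketch:
        """Simplified stand-in for a LlamacppHF-style class: forward passes go
        through llama.cpp, and logits come back in the format generate() expects."""

        def __init__(self, model_path: str, **llama_kwargs):
            self.model = Llama(model_path=model_path, logits_all=True, **llama_kwargs)

        def __call__(self, input_ids: torch.LongTensor, **kwargs) -> CausalLMOutputWithPast:
            # Re-evaluate the full sequence for clarity; a real wrapper keeps the
            # llama.cpp cache and only evaluates the new suffix.
            self.model.reset()
            self.model.eval(input_ids[0].tolist())
            scores = torch.tensor(self.model.scores[: self.model.n_tokens]).float()
            return CausalLMOutputWithPast(logits=scores[None, :, :])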

@LysandreJik
Member

Understood! I'm not closed to the idea; as we split up the monolithic generate, let's see if this isn't something we can easily solve at the same time.

Option 3 as discussed above is being drafted in #30391


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@endomorphosis

bump

@younesbelkada
Contributor

FYI, option 3 (converting GGUF quants to transformers) has already landed! See #30391.
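
Loading a GGUF checkpoint with that feature looks roughly like the snippet below. The repo id and file name are placeholders, and the gguf_file argument should be checked against the current transformers documentation:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Placeholders: any Hub repository that ships a GGUF file.
    model_id = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
    gguf_file = "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf"

    # The GGUF tensors are dequantized to regular torch weights on load.
    tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=gguf_file)
    model = AutoModelForCausalLM.from_pretrained(model_id, gguf_file=gguf_file)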

@huggingface huggingface deleted a comment from github-actions bot Jun 11, 2024
@gante added the WIP label Jun 14, 2024
@compilade

compilade commented Aug 15, 2024

@LysandreJik, @younesbelkada, @kskd1804, @SunMarc

gguf-py can now natively dequantize most quant types to float32 (including i-quants) purely with Numpy, see ggerganov/llama.cpp#8939

This can be used through the gguf.quants.dequantize function (also available as gguf.dequantize).

https://github.com/ggerganov/llama.cpp/blob/5fd89a70ead34d1a17015ddecad05aaa2490ca46/gguf-py/gguf/quants.py#L67
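
A short sketch of how that could be used (the GGUFReader usage and tensor attribute names here are assumptions based on gguf-py, and the file path is a placeholder):

    from gguf import GGUFReader, dequantize

    reader = GGUFReader("path/to/model.Q4_K_M.gguf")  # placeholder path

    for tensor in reader.tensors:
        # dequantize() unpacks the quantized blocks into float32 values.
        weights = dequantize(tensor.data, tensor.tensor_type)
        print(tensor.name, weights.dtype, weights.shape)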

EDIT: the (upstream) gguf-py version has not been bumped yet (so this is not on PyPI yet); I'll see if I can do that in the next few days.

@SunMarc
Member

SunMarc commented Aug 19, 2024

Thanks for the heads-up @compilade! We already have a PR open to use the gguf.quants.dequantize function. When this lands on PyPI, I will merge the PR!
