
All examples fail using ExllamaV2 #583

Closed
dnhkng opened this issue Jan 25, 2024 · 1 comment · Fixed by #584
Labels: bug

@dnhkng
Contributor

dnhkng commented Jan 25, 2024

Describe the issue as clearly as possible:

Running any of the examples given on the front page of the repo produces the same error:

RuntimeError: Index put requires the source and destination dtypes match, got Float for the destination and Half for the source.

Steps/code to reproduce the bug:

import outlines

model = outlines.models.exl2(
    model_name="models/TinyLlama-1.1B-Chat-v1.0-5.0bpw-h6-exl2",
    model_kwargs={"num_experts_per_token": 1},
    device="cpu",
)

prompt = """You are a sentiment-labelling assistant.
Is the following review positive or negative?

Review: This restaurant is just awesome!
"""

generator = outlines.generate.choice(model, ["Positive", "Negative"])
answer = generator(prompt)

Expected result:

Either "Positive" or "Negative".

Error message:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[34], line 10
      3 prompt = """You are a sentiment-labelling assistant.
      4 Is the following review positive or negative?
      5 
      6 Review: This restaurant is just awesome!
      7 """
      9 generator = outlines.generate.choice(model, ["Positive", "Negative"])
---> 10 answer = generator(prompt)

File ~/miniforge3/envs/exllama/lib/python3.10/site-packages/outlines/generate/api.py:213, in SequenceGenerator.__call__(self, prompts, max_tokens, stop_at, rng, kv_cache)
    211 while True:
    212     try:
--> 213         last_state = next(states)
    214         if max_tokens or stop_sequences:
    215             generated_token_ids = self.get_generated_token_ids(
    216                 init_state, prompts, last_state
    217             )

File ~/miniforge3/envs/exllama/lib/python3.10/site-packages/outlines/generate/generator.py:84, in sequence_generator(token_generator, fsms, init_state, fsm_states, rng)
     81 while True:
     82     allowed_tokens = get_allowed_tokens(fsms, fsm_states)
---> 84     next_token_ids, kv_cache, logits, _ = token_generator(
     85         token_ids,
     86         attention_masks,
     87         kv_cache,
     88         rng=rng,
     89         allowed_tokens=allowed_tokens,
     90     )
     92     token_ids = update_token_ids(token_ids, next_token_ids)
     93     attention_masks = expand_attention_masks(attention_masks)

File ~/miniforge3/envs/exllama/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/miniforge3/envs/exllama/lib/python3.10/site-packages/outlines/generate/generator.py:144, in token_generator.<locals>.generate(token_ids, attention_masks, kv_cache, allowed_tokens, rng)
    139 except IndexError:  # Exceeding the context length
    140     raise IndexError(
    141         "The input length exceeds the context length of the model."
    142     )
--> 144 biased_logits = bias_logits(logits, allowed_tokens)
    145 next_token_ids = sampler(biased_logits, 1, rng)
    147 return next_token_ids, new_kv_cache, logits, biased_logits

File ~/miniforge3/envs/exllama/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/miniforge3/envs/exllama/lib/python3.10/site-packages/outlines/generate/generator.py:287, in bias_logits(logits, allowed_token_ids)
    285 biased_logits = torch.full(logits.shape, -math.inf, device=logits.device)
    286 for i, ids in enumerate(allowed_token_ids):
--> 287     biased_logits[i, ids] = logits[i, ids]
    288 return biased_logits

RuntimeError: Index put requires the source and destination dtypes match, got Float for the destination and Half for the source.
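
For what it's worth, the failure is in bias_logits: torch.full creates biased_logits with the default float32 dtype, while the ExllamaV2 logits arrive as float16. The mismatch reproduces in plain PyTorch (minimal sketch, independent of Outlines; shapes and ids are illustrative):

import torch

dest = torch.full((1, 8), -float("inf"))      # Float destination, as bias_logits builds it
src = torch.randn(1, 8, dtype=torch.float16)  # Half source, standing in for ExllamaV2 logits
ids = [0, 3]
dest[0, ids] = src[0, ids]  # RuntimeError: Index put requires the source and destination dtypes match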

Outlines/Python version information:

Note: I have installed outlines with pip, but there is no __version__ attribute!
>>> dir(outlines)
['Function',
 '__all__',
 '__builtins__',
 '__cached__',
 '__doc__',
 '__file__',
 '__loader__',
 '__name__',
 '__package__',
 '__path__',
 '__spec__',
 'base',
 'caching',
 'clear_cache',
 'disable_cache',
 'fsm',
 'function',
 'generate',
 'get_cache',
 'models',
 'outlines',
 'prompt',
 'prompts',
 'text',
 'vectorize']

Python 3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:36:39) [GCC 12.3.0]
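
In the meantime, the pip-installed version can be read from the standard library's package metadata:

from importlib.metadata import version  # stdlib since Python 3.8

print(version("outlines"))  # prints the version that pip installed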

Context for the issue:

ExllamaV2 is a very good local inference framework.

dnhkng added the bug label Jan 25, 2024
@dnhkng
Contributor Author

dnhkng commented Jan 25, 2024

I've fixed it by modifying outlines/generate/generator.py. The original code:

    285 biased_logits = torch.full(logits.shape, -math.inf, device=logits.device)
    286 for i, ids in enumerate(allowed_token_ids):
--> 287     biased_logits[i, ids] = logits[i, ids]
    288 return biased_logits

is changed to cast the source to the destination's dtype:

    285 biased_logits = torch.full(logits.shape, -math.inf, device=logits.device)
    286 for i, ids in enumerate(allowed_token_ids):
--> 287     biased_logits[i, ids] = logits[i, ids].to(biased_logits.dtype)
    288 return biased_logits

Pull request is #584
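
The cast can be checked in isolation, mirroring the sketch above (again with illustrative shapes and ids):

import torch

dest = torch.full((1, 8), -float("inf"))      # Float destination
src = torch.randn(1, 8, dtype=torch.float16)  # Half source
ids = [0, 3]
dest[0, ids] = src[0, ids].to(dest.dtype)     # cast Half -> Float first; no error now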

dnhkng added a commit to dnhkng/outlines that referenced this issue Jan 25, 2024
@rlouf rlouf linked a pull request Jan 25, 2024 that will close this issue
rlouf pushed a commit that referenced this issue Jan 25, 2024
ExllamaV2 generates logits in torch's Half dtype, but Outlines requires the Float dtype.

This small change converts the logits to the required dtype (whatever that might be), solving issue #583.

Tested with ExllamaV2 on the example code from the GitHub front page; #583 is resolved.