Tokenizer Aware FSM Construction #383

Closed

simon-mo opened this issue Nov 21, 2023 · 1 comment
Comments

@simon-mo

What behavior of the library made you think about the improvement?

Please correct me if I'm wrong, but the current FSM construction code appears to be O(vocab × states). Even with Numba's acceleration it is still fairly time-consuming; for example, the IP address example took about 9 s for me.
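
To make that complexity concrete, here is a rough, hypothetical sketch of the kind of scan I mean (not Outlines' actual implementation; `dfa_transitions`, `states`, and `vocab` are made-up names, with `vocab` mapping token strings to token ids):

def build_token_transition_table(dfa_transitions, states, vocab):
    """Naive O(states x vocab x token_length) scan: for each DFA state,
    find which vocabulary tokens can be consumed and where they lead."""
    table = {}
    for state in states:                            # O(states)
        for token_str, token_id in vocab.items():   # O(vocab)
            current = state
            for char in token_str:                  # walk the token one character at a time
                current = dfa_transitions.get((current, char))
                if current is None:
                    break
            else:
                table[(state, token_id)] = current
    return table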

One observation: a tokenizer's vocabulary is usually far larger than its set of distinct characters because of Byte-Pair Encoding. I believe that for both the GPT-4 and Llama tokenizers only about 10% of the vocabulary consists of single-character tokens; the rest are combinations of those single-character tokens.
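
As a rough sanity check of that ratio, one could count single-character entries in a Hugging Face tokenizer's vocabulary; the exact percentage depends on how byte-level BPE represents characters, so treat this as an estimate rather than a precise figure:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
vocab = tokenizer.get_vocab()  # maps token string -> token id

# Count entries whose string form is a single character (byte-level BPE means
# this only approximates "single character" in the usual sense).
single_char = sum(1 for token in vocab if len(token) == 1)
print(f"{single_char} / {len(vocab)} tokens ({single_char / len(vocab):.1%}) are single characters")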

Would it be possible to make construction faster by evaluating only the single-character tokens up front and deferring the lookup for the composite ones until later? I believe a similar idea exists in lm-format-enforcer's tokenizer prefix tree (see the sketch below).
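
For reference, here is a minimal sketch of that prefix-tree (trie) idea; it is not lm-format-enforcer's actual code, just an illustration of indexing the vocabulary by character prefixes so composite tokens can be matched incrementally:

class TrieNode:
    def __init__(self):
        self.children = {}   # character -> TrieNode
        self.token_ids = []  # vocabulary tokens that end exactly at this node


def build_token_trie(vocab):
    """Build a character-level prefix tree over the vocabulary
    (vocab maps token string -> token id)."""
    root = TrieNode()
    for token_str, token_id in vocab.items():
        node = root
        for char in token_str:
            node = node.children.setdefault(char, TrieNode())
        node.token_ids.append(token_id)
    return root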

How would you like it to behave?

No response

@brandonwillard (Member) commented Nov 21, 2023

We already have an O(vocab) approach in this library and an issue for it: #304. If I'm understanding the tokenizer prefix tree idea you mentioned correctly, I don't think it would address the sources of latency in our DFA construction process.

Aside from that, there are quite a few ways to improve the current default approach (e.g. #301), but it should be clear that building these DFAs is a one-time, "offline" cost per vocabulary and regex. The most time-consuming part of that process is the construction of intermediate typed Numba collection objects for each distinct vocabulary. Those are cached in memory, so it's a one-time cost per session, at least until we set it up to cache them to disk (see #303).

Here's an example illustrating those sources of latency for the IP address example you mentioned:

import time
import timeit
import torch
import numpy as np
import outlines.models as models
import outlines.text.generate as generate


model = models.transformers("gpt2-medium", device="cuda")

prompt = "What is the IP address of the Google DNS servers? "

rng = torch.Generator(device="cuda")
rng.manual_seed(10000)

# We set the max tokens to somewhere near the max achievable with guided
# generation for a somewhat more reasonable comparison
unguided_generator = generate.continuation(model, max_tokens=8)

unguided_generator(prompt, rng=rng)
# "\xa0If you'd like to see a"

unguided_timer = timeit.Timer(
    "unguided_generator(prompt, rng=rng)", "pass", globals=locals()
)
unguided_times = np.array(unguided_timer.repeat(50, 1))

# Initial DFA construction: intermediate model-specific objects are constructed
# and cached in-memory.
start_time = time.perf_counter_ns()
guided_generator = generate.regex(
    model,
    r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)",
    max_tokens=30,
)
duration = (time.perf_counter_ns() - start_time) / 1e9

print(duration)
# 7.182685829

# Constructing another similar regex shouldn't take as much time now that
# the model-specific intermediate objects are cached
start_time = time.perf_counter_ns()
_ = generate.regex(
    model,
    r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)\d*",
    max_tokens=30,
)
duration = (time.perf_counter_ns() - start_time) / 1e9

print(duration)
# 0.434157875

# Reuse the `guided_generator` object with different prompts to avoid a little
# overhead
guided_generator(prompt, rng=rng)
# '167.45.37.228'

guided_timer = timeit.Timer(
    "guided_generator(prompt, rng=rng)", "pass", globals=locals()
)
guided_times = np.array(guided_timer.repeat(50, 1))

# Guided generation should only incur a fixed cost (e.g. dict look-up) compared
# to unguided, and one that we can still significantly reduce (see
# https://github.com/outlines-dev/outlines/issues/317)
print(np.mean(unguided_times - guided_times))
# -0.009887300101108849

print(np.std(unguided_times - guided_times))
# 0.017138939122497605

print(np.max(unguided_times - guided_times))
# 0.024637886963319033

print(np.min(unguided_times - guided_times))
# -0.06161932600662112

Since we already have open issues addressing the identified sources of latency, I'm going to close this one for now.

brandonwillard closed this as not planned on Nov 21, 2023