# Tokenizer Aware FSM Construction #383
We already have an O(vocab) approach in this library and an issue for it: #304. If I'm understanding the tokenizer prefix tree idea that you mentioned, I really don't think it would address the sources of latency in our DFA construction process. Aside from that, there are quite a few ways to improve the current default approach (e.g. #301), but it should be clear that building these DFAs is a one-time "offline" step per vocabulary and regex. The most time-consuming part of that process is the construction of intermediate typed Numba collection objects for each distinct vocabulary. Those are cached in memory, so it's a one-time cost per session, at least until we set it up to cache them to disk (see #303). Here's an example illustrating those sources of latency for the IP address example you mentioned:

```python
import time
import timeit

import numpy as np
import torch

import outlines.models as models
import outlines.text.generate as generate

model = models.transformers("gpt2-medium", device="cuda")

prompt = "What is the IP address of the Google DNS servers? "

rng = torch.Generator(device="cuda")
rng.manual_seed(10000)

# We set the max tokens to somewhere near the max achievable with guided
# generation for a somewhat more reasonable comparison
unguided_generator = generate.continuation(model, max_tokens=8)

unguided_generator(prompt, rng=rng)
# "\xa0If you'd like to see a"

unguided_timer = timeit.Timer(
    "unguided_generator(prompt, rng=rng)", "pass", globals=locals()
)
unguided_times = np.array(unguided_timer.repeat(50, 1))

# Initial DFA construction: intermediate model-specific objects are constructed
# and cached in-memory.
start_time = time.perf_counter_ns()
guided_generator = generate.regex(
    model,
    r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)",
    max_tokens=30,
)
duration = (time.perf_counter_ns() - start_time) / 1e9
print(duration)
# 7.182685829

# Constructing another similar regex shouldn't take as much time now that
# the model-specific intermediate objects are cached
start_time = time.perf_counter_ns()
_ = generate.regex(
    model,
    r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)\d*",
    max_tokens=30,
)
duration = (time.perf_counter_ns() - start_time) / 1e9
print(duration)
# 0.434157875

# Reuse the `guided_generator` object with different prompts to avoid a little
# overhead
guided_generator(prompt, rng=rng)
# '167.45.37.228'

guided_timer = timeit.Timer(
    "guided_generator(prompt, rng=rng)", "pass", globals=locals()
)
guided_times = np.array(guided_timer.repeat(50, 1))

# Guided generation should only incur a fixed cost (e.g. a dict look-up)
# compared to unguided, and one that we can still significantly reduce (see
# https://github.com/outlines-dev/outlines/issues/317)
print(np.mean(unguided_times - guided_times))
# -0.009887300101108849
print(np.std(unguided_times - guided_times))
# 0.017138939122497605
print(np.max(unguided_times - guided_times))
# 0.024637886963319033
print(np.min(unguided_times - guided_times))
# -0.06161932600662112
```

Since we already have open issues addressing the identified sources of latency, I'm going to close this one for now.
**What behavior of the library made you think about the improvement?**
Please correct me if I'm wrong, but the current FSM construction code is O(vocab) × O(states). Even with Numba's acceleration it is still pretty time-consuming; for example, the IP address example took 9 s for me.
One observation: a tokenizer's vocabulary is often much larger than its character set because of Byte-Pair Encoding. I believe that for both the GPT-4 and Llama tokenizers, only about 10% of the vocabulary consists of single-character tokens; the rest are combinations of them.
Would it be possible to make this faster by evaluating only the single-character tokens up front and deferring the lookup for the composite ones? I believe a similar idea exists in lm-format-enforcer's tokenizer prefix tree.
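For reference, the prefix-tree idea could be sketched like this. It is a hypothetical illustration of an lm-format-enforcer-style trie, not that library's actual code: composite BPE tokens share character paths, so one character-level descent can discover every vocabulary token consistent with the current position.

```python
# Minimal tokenizer prefix tree (trie): multi-character BPE tokens share
# prefixes, so walking characters once finds all matching vocab tokens.

class TrieNode:
    def __init__(self):
        self.children = {}  # char -> TrieNode
        self.token = None   # set when a full vocabulary token ends here

def build_trie(vocab):
    root = TrieNode()
    for tok in vocab:
        node = root
        for ch in tok:
            node = node.children.setdefault(ch, TrieNode())
        node.token = tok
    return root

def tokens_with_prefix(root, prefix):
    # All vocabulary tokens starting with `prefix`, found in one descent
    # plus a traversal of the matching subtree.
    node = root
    for ch in prefix:
        node = node.children.get(ch)
        if node is None:
            return []
    out, stack = [], [node]
    while stack:
        n = stack.pop()
        if n.token is not None:
            out.append(n.token)
        stack.extend(n.children.values())
    return sorted(out)

root = build_trie(["1", "12", "123", "2", "a1"])
print(tokens_with_prefix(root, "1"))  # ['1', '12', '123']
```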
**How would you like it to behave?**
No response