Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Text Generation] Optimize the slow update method in the KVCacheDecoder #1190

Merged
merged 5 commits into from
Aug 24, 2023
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 64 additions & 73 deletions src/deepsparse/transformers/utils/decoder_kv_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,79 +103,78 @@ def update(
Corresponds to `input_ids.shape[1]`
"""
self.total_num_processed_tokens += input_ids_len
total_cache_capacity = state[list(state.keys())[0]].shape[

dbogunowicz marked this conversation as resolved.
Show resolved Hide resolved
input_state_capacity = state[list(state.keys())[0]].shape[
self._sequence_len_axis
]
# total_capacity = num_tokens (num of non-blank tokens) +
# + num_padded_entries (num of blank tokens)
num_entries_to_delete = input_ids_len

# compute the number of blank (padded) entries in the cache
num_padded_entries = max(
0, total_cache_capacity - self.total_num_processed_tokens
0, input_state_capacity - self.total_num_processed_tokens
)
num_entries_to_delete = input_ids_len
# compute how many of those entries need to be deleted
num_padded_entries_to_delete = min(num_padded_entries, num_entries_to_delete)

if num_padded_entries:
"""
Transforms input KV cache that contains blank entries.
It removes the rightmost blank entries from the cache.
Example 1:
(entries in the cache denote the order in which they were
added to the cache, zero is to denote a blank entry)
```
state["state_name"]: (1, 1, 5, 1) = array([[[[0], [0], [1], [2], [3]]]])
-> num_padded_entries = 2
-> num_entries_to_delete = 1
-> num_padded_entries > num_entries_to_delete
# there are more blank entries than entries to delete
results in:
state["state_name"]: (1, 1, 4, 1) = array([[[[0], [1], [2], [3]]]])
```
Example 2:
```
state["state_name"]: (1, 1, 6, 1) = array([[[[0], [0], [0], [1], [2], [3]]]]) # noqa: E501
-> num_padded_entries = 3
-> num_entries_to_delete = 5
-> num_padded_entries < num_entries_to_delete
# there are less blank entries than entries to delete
results in:
state["state_name"]: (1, 1, 3, 1) = array([[[[1], [2], [3]]]])
```
"""
num_padded_entries_to_delete = min(
num_padded_entries, num_entries_to_delete
)
idxs_to_remove = [
num_padded_entries - i - 1 for i in range(num_padded_entries_to_delete)
]
# if we had fewer blank entries than entries to delete,
# we updated the number of entries to delete to a non-zero value
num_entries_to_delete = max(0, num_entries_to_delete - num_padded_entries)
# update the state of the cache
state = self._delete_entries(state, idxs_to_remove)

if num_entries_to_delete:
"""
Transforms the input KV cache that has been totally
filled with non-blank entries.
Example:
```
state["state_name"]: (1, 1, 5, 1) = array([[[[1], [2], [3], [4], [5]]]])
num_entries_to_delete = 2
if self.freeze_first_position == False:
state["state_name"]: (1, 1, 3, 1) = array([[[[3], [4], [5]]]])
else:

state["state_name"]: (1, 1, 3, 1) = array([[[[1], [4], [5]]]])
```
"""
idxs_to_remove = [
i + int(self._freeze_first_position)
for i in range(num_entries_to_delete)
]

state = self._delete_entries(state, idxs_to_remove)
# if we had fewer padded entries than num_entries_to_delete,
# we additionally are forced to delete some non-padded entries (the oldest ones)
num_non_padded_entries_to_delete = max(
0, num_entries_to_delete - num_padded_entries
)

for name, cache_array in state.items():
if num_padded_entries_to_delete:
cache_array = self.remove_padded_entries(
cache_array, num_padded_entries_to_delete
)
if num_non_padded_entries_to_delete:
cache_array = self.remove_non_padded_entries(
cache_array, num_entries_to_delete
)
state[name] = numpy.ascontiguousarray(cache_array)

self._state = state

def remove_padded_entries(
self, cache_array: numpy.ndarray, num_padded_entries_to_delete: int
):
"""
Remove the num_padded_entries_to_delete entries from the cache array.
This function assumes that the cache_array has the number
of padded (blank) entries that is equal/larger than
num_padded_entries_to_delete.

:param cache_array: The cache array to be modified.
:param num_padded_entries_to_delete: The number of padded entries to delete.
"""
return cache_array[:, :, num_padded_entries_to_delete:, :]

def remove_non_padded_entries(
self, cache_array: numpy.ndarray, num_non_padded_entries_to_delete: int
):
"""
Remove the num_non_padded_entries_to_delete entries from the cache array.
This function assumes that the cache_array has no padded (blank) entries and
thus we are forced to delete the oldest entries from the cache.

If self._freeze_first_position is set to True, that means that the oldest
entry in the cache_array is the one that corresponds to the BOS token. Because
we want to keep that entry in the cache, we will delete the oldest entry
starting from the second oldest entry.
"""
new_cache_array = cache_array[
:,
:,
bool(self._freeze_first_position) + num_non_padded_entries_to_delete :,
:,
]
if self._freeze_first_position:
bos_entries = cache_array[:, :, :1, :]
dbogunowicz marked this conversation as resolved.
Show resolved Hide resolved
new_cache_array = numpy.concatenate(
[bos_entries, new_cache_array], axis=self._sequence_len_axis
dbogunowicz marked this conversation as resolved.
Show resolved Hide resolved
)
return new_cache_array

def set_capacity(self, capacity: int):
"""
Enforce a new total capacity for the state
Expand Down Expand Up @@ -212,14 +211,6 @@ def set_capacity(self, capacity: int):

self._state = state

def _delete_entries(
self, state: Dict[str, Any], indices: List[int]
) -> Dict[str, Any]:
for key, value in state.items():
state[key] = numpy.delete(value, indices, axis=self._sequence_len_axis)
state[key] = numpy.ascontiguousarray(state[key])
return state

def _add_entries(
self, state: Dict[str, Any], indices: List[int], padding_value: int = 0
) -> Dict[str, Any]:
Expand Down
Loading