Skip to content

Commit

Permalink
[Text Generation] Optimize the slow update method in the KVCacheDec…
Browse files Browse the repository at this point in the history
…oder (#1190)

* initial commit

* Nit: docstring typo

* fix style
  • Loading branch information
dbogunowicz committed Aug 24, 2023
1 parent 2cf112a commit 1bd60d2
Showing 1 changed file with 65 additions and 74 deletions.
139 changes: 65 additions & 74 deletions src/deepsparse/transformers/utils/decoder_kv_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def update(
):
"""
Updating the session is identical with taking the kv cache
output of from the forward pass and restructuring it, so it
output from the forward pass and restructuring it, so it
can be directly used as input for the next forward pass.
:param state: The state of the cache. This is a dictionary
Expand All @@ -103,79 +103,78 @@ def update(
Corresponds to `input_ids.shape[1]`
"""
self.total_num_processed_tokens += input_ids_len
total_cache_capacity = state[list(state.keys())[0]].shape[

input_state_capacity = state[list(state.keys())[0]].shape[
self._sequence_len_axis
]
# total_capacity = num_tokens (num of non-blank tokens) +
# + num_padded_entries (num of blank tokens)
num_entries_to_delete = input_ids_len

# compute the number of blank (padded) entries in the cache
num_padded_entries = max(
0, total_cache_capacity - self.total_num_processed_tokens
0, input_state_capacity - self.total_num_processed_tokens
)
num_entries_to_delete = input_ids_len
# compute how many of those entries need to be deleted
num_padded_entries_to_delete = min(num_padded_entries, num_entries_to_delete)

if num_padded_entries:
"""
Transforms input KV cache that contains blank entries.
It removes the rightmost blank entries from the cache.
Example 1:
(entries in the cache denote the order in which they were
added to the cache, zero is to denote a blank entry)
```
state["state_name"]: (1, 1, 5, 1) = array([[[[0], [0], [1], [2], [3]]]])
-> num_padded_entries = 2
-> num_entries_to_delete = 1
-> num_padded_entries > num_entries_to_delete
# there are more blank entries than entries to delete
results in:
state["state_name"]: (1, 1, 4, 1) = array([[[[0], [1], [2], [3]]]])
```
Example 2:
```
state["state_name"]: (1, 1, 6, 1) = array([[[[0], [0], [0], [1], [2], [3]]]]) # noqa: E501
-> num_padded_entries = 3
-> num_entries_to_delete = 5
-> num_padded_entries < num_entries_to_delete
# there are less blank entries than entries to delete
results in:
state["state_name"]: (1, 1, 3, 1) = array([[[[1], [2], [3]]]])
```
"""
num_padded_entries_to_delete = min(
num_padded_entries, num_entries_to_delete
)
idxs_to_remove = [
num_padded_entries - i - 1 for i in range(num_padded_entries_to_delete)
]
# if we had fewer blank entries than entries to delete,
# we updated the number of entries to delete to a non-zero value
num_entries_to_delete = max(0, num_entries_to_delete - num_padded_entries)
# update the state of the cache
state = self._delete_entries(state, idxs_to_remove)

if num_entries_to_delete:
"""
Transforms the input KV cache that has been totally
filled with non-blank entries.
Example:
```
state["state_name"]: (1, 1, 5, 1) = array([[[[1], [2], [3], [4], [5]]]])
num_entries_to_delete = 2
if self.freeze_first_position == False:
state["state_name"]: (1, 1, 3, 1) = array([[[[3], [4], [5]]]])
else:
state["state_name"]: (1, 1, 3, 1) = array([[[[1], [4], [5]]]])
```
"""
idxs_to_remove = [
i + int(self._freeze_first_position)
for i in range(num_entries_to_delete)
]

state = self._delete_entries(state, idxs_to_remove)
# if we had fewer padded entries than num_entries_to_delete,
# we are additionally forced to delete some non-padded entries (the oldest ones)
num_non_padded_entries_to_delete = max(
0, num_entries_to_delete - num_padded_entries
)

for name, cache_array in state.items():
if num_padded_entries_to_delete:
cache_array = self.remove_padded_entries(
cache_array, num_padded_entries_to_delete
)
if num_non_padded_entries_to_delete:
cache_array = self.remove_non_padded_entries(
cache_array, num_entries_to_delete
)
state[name] = numpy.ascontiguousarray(cache_array)

self._state = state

def remove_padded_entries(
    self, cache_array: numpy.ndarray, num_padded_entries_to_delete: int
) -> numpy.ndarray:
    """
    Remove the first num_padded_entries_to_delete entries from the
    cache array, along the sequence length axis.

    This function assumes that the cache_array has the number
    of padded (blank) entries that is equal/larger than
    num_padded_entries_to_delete. It drops entries from the start of
    the sequence length axis, so it presumably relies on the padded
    entries being stored there (TODO: confirm against the caller).

    :param cache_array: The cache array to be modified.
    :param num_padded_entries_to_delete: The number of padded entries to delete.
    :return: The cache array without its first num_padded_entries_to_delete
        entries along the sequence length axis. The result is a view of
        cache_array (basic slicing), not a copy.
    """
    # Slice along the configured sequence length axis instead of assuming
    # a 4-D array with the sequence at position 2, so this stays consistent
    # with every other place that uses self._sequence_len_axis.
    selection = [slice(None)] * cache_array.ndim
    selection[self._sequence_len_axis] = slice(num_padded_entries_to_delete, None)
    return cache_array[tuple(selection)]

def remove_non_padded_entries(
    self, cache_array: numpy.ndarray, num_non_padded_entries_to_delete: int
) -> numpy.ndarray:
    """
    Remove the num_non_padded_entries_to_delete entries from the cache array.

    This function assumes that the cache_array has no padded (blank) entries and
    thus we are forced to delete the oldest entries from the cache, i.e. the
    ones at the start of the sequence length axis.

    If self._freeze_first_position is set to True, that means that the oldest
    entry in the cache_array is the one that corresponds to the BOS token. Because
    we want to keep that entry in the cache, the deletion starts from the
    second oldest entry and the first entry is put back in front of the result.

    Example (sequence length axis shown only):
    ```
    cache_array = [1, 2, 3, 4, 5], num_non_padded_entries_to_delete = 2
    if self._freeze_first_position == False:
        result = [3, 4, 5]
    else:
        result = [1, 4, 5]
    ```

    :param cache_array: The cache array to be modified.
    :param num_non_padded_entries_to_delete: The number of oldest
        (non-padded) entries to delete.
    :return: The cache array without the deleted entries. It is a view of
        cache_array when the first position is not frozen, and a newly
        allocated array (result of numpy.concatenate) otherwise.
    """
    axis = self._sequence_len_axis
    keep_first = bool(self._freeze_first_position)

    def _take(start, stop):
        # Slice along the configured sequence length axis. The same axis
        # is used for the concatenation below, so slicing and re-assembly
        # can never disagree.
        selection = [slice(None)] * cache_array.ndim
        selection[axis] = slice(start, stop)
        return cache_array[tuple(selection)]

    # When the first position is frozen it is skipped here in addition to
    # the deleted entries, and re-attached afterwards.
    remaining = _take(int(keep_first) + num_non_padded_entries_to_delete, None)
    if not keep_first:
        return remaining

    bos_entries = _take(None, 1)
    return numpy.concatenate([bos_entries, remaining], axis=axis)

def set_capacity(self, capacity: int):
"""
Enforce a new total capacity for the state
Expand Down Expand Up @@ -212,14 +211,6 @@ def set_capacity(self, capacity: int):

self._state = state

def _delete_entries(
self, state: Dict[str, Any], indices: List[int]
) -> Dict[str, Any]:
for key, value in state.items():
state[key] = numpy.delete(value, indices, axis=self._sequence_len_axis)
state[key] = numpy.ascontiguousarray(state[key])
return state

def _add_entries(
self, state: Dict[str, Any], indices: List[int], padding_value: int = 0
) -> Dict[str, Any]:
Expand Down

0 comments on commit 1bd60d2

Please sign in to comment.