Skip to content

Commit

Permalink
internal kv_cache update for batch_size > 1 (#1514)
Browse files (browse the repository at this point in the history)
  • Loading branch information
dsikka committed Jan 5, 2024
1 parent f2530e3 commit 05a47b0
Showing 1 changed file with 11 additions and 6 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -208,7 +208,9 @@ def run(self, inp: NLEngineInputs, **kwargs) -> NLEngineOutputs:
inputs = list(map(self._add_kv_cache_to_input, engine_input, kv_cache))
inputs = join_engine_outputs(inputs, len(inputs))

if bool(kv_cache[0].engine_internal_cache):
internal_kv_cache_present = bool(kv_cache[0].engine_internal_cache)

if internal_kv_cache_present:
# conventionally, before dispatching
# inputs to the engine, we validate them
# if val_inp=True. However, in this case
Expand All @@ -235,18 +237,21 @@ def run(self, inp: NLEngineInputs, **kwargs) -> NLEngineOutputs:
)

# logits should be stacked along batch dim
# kv_cache_state should be a list where each dim 0 is batch_size
# kv_cache_state should be a list where each item has dim 0 as batch_size
logits, *kv_cache_state = out
kv_cache_state, _ = split_engine_inputs(kv_cache_state, 1)

if len(kv_cache_state) > 0:
if not internal_kv_cache_present:
# split along batch sizes; will give a list of lists where number of lists
# is equal to batch_size
kv_cache_state, _ = split_engine_inputs(kv_cache_state, 1)
for i in range(len(kv_cache)):
# pass in a list and kv_cache object per _update_kv_cache call
self._update_kv_cache(
kv_cache_state=kv_cache_state[i], kv_cache=kv_cache[i]
)
else:
# internal kv cache case
self._update_kv_cache(kv_cache=kv_cache[0])
for i in range(len(kv_cache)):
self._update_kv_cache(kv_cache=kv_cache[i])

output = {
"engine_outputs": logits,
Expand Down

0 comments on commit 05a47b0

Please sign in to comment.