
Commit 93cd2ea
change DynamicCache function names from "split" to "batch_split" for readability + apply coding style
Cyrilvallez committed May 21, 2024
1 parent e4ad53a commit 93cd2ea
Showing 3 changed files with 12 additions and 9 deletions.
src/transformers/cache_utils.py (6 changes: 3 additions & 3 deletions)
@@ -201,7 +201,7 @@ def crop(self, maximum_length: int):
             self.key_cache[idx] = self.key_cache[idx][..., :maximum_length, :]
             self.value_cache[idx] = self.value_cache[idx][..., :maximum_length, :]

-    def split(self, full_batch_size: int, split_size: int) -> List["DynamicCache"]:
+    def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicCache"]:
         """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
         `_split_model_inputs()` in `generation.utils`"""
         out = []
@@ -214,8 +214,8 @@ def split(self, full_batch_size: int, split_size: int) -> List["DynamicCache"]:
         return out

     @classmethod
-    def from_splits(cls, splits: List["DynamicCache"]) -> "DynamicCache":
-        """This is the opposite of the above `split()` method. This will be used by `stack_model_outputs` in
+    def from_batch_splits(cls, splits: List["DynamicCache"]) -> "DynamicCache":
+        """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
         `generation.utils`"""
         cache = cls()
         for idx in range(len(splits[0])):
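For readers skimming the rename, a minimal round-trip sketch of the new pair. This is not part of the commit; the toy tensor shapes and the use of `DynamicCache.update()` to fill the cache are illustrative assumptions only:

import torch
from transformers import DynamicCache

# Toy cache: one layer, batch of 4, 2 heads, 3 tokens, head dim 8 (illustrative shapes).
cache = DynamicCache()
cache.update(torch.randn(4, 2, 3, 8), torch.randn(4, 2, 3, 8), layer_idx=0)

# batch_split slices the batch of 4 into two DynamicCache instances of batch size 2.
chunks = cache.batch_split(full_batch_size=4, split_size=2)
assert len(chunks) == 2

# from_batch_splits is the inverse: re-concatenate the chunks along the batch dimension.
merged = DynamicCache.from_batch_splits(chunks)
assert merged.key_cache[0].shape == (4, 2, 3, 8)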
src/transformers/generation/candidate_generator.py (7 changes: 5 additions & 2 deletions)
@@ -115,10 +115,13 @@ def __init__(
                 assistant_kwargs[key] = (
                     value.detach().to(device) if isinstance(value, torch.Tensor) else copy.deepcopy(value)
                 )

         # Remove potential default DynamicCache if assistant does not support it
         if "past_key_values" in assistant_kwargs.keys():
-            if isinstance(assistant_kwargs["past_key_values"], DynamicCache) and not self.assistant_model._supports_dynamic_cache_class:
+            if (
+                isinstance(assistant_kwargs["past_key_values"], DynamicCache)
+                and not self.assistant_model._supports_dynamic_cache_class
+            ):
                 # Cache is empty -> remove it from kwargs
                 if len(assistant_kwargs["past_key_values"]) == 0:
                     del assistant_kwargs["past_key_values"]
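As a standalone illustration of the guard reformatted above: a hypothetical helper, assuming a plain kwargs dict and a model object carrying the `_supports_dynamic_cache_class` flag seen in the diff:

from transformers import DynamicCache

def drop_unsupported_cache(assistant_kwargs, assistant_model):
    """Hypothetical helper mirroring the guard above."""
    past = assistant_kwargs.get("past_key_values")
    if isinstance(past, DynamicCache) and not getattr(assistant_model, "_supports_dynamic_cache_class", False):
        if len(past) == 0:  # cache is empty -> safe to remove from kwargs
            del assistant_kwargs["past_key_values"]
    return assistant_kwargs

class _Stub:  # stand-in for an assistant model without DynamicCache support
    _supports_dynamic_cache_class = False

kwargs = drop_unsupported_cache({"past_key_values": DynamicCache()}, _Stub())
assert "past_key_values" not in kwargs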
src/transformers/generation/utils.py (8 changes: 4 additions & 4 deletions)
@@ -3793,9 +3793,9 @@ def _split(data, full_batch_size: int, split_size: int = None):
         return [None] * (full_batch_size // split_size)
     if isinstance(data, torch.Tensor):
         return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)]
-    # New efficient cache
+    # New cache format
     elif isinstance(data, DynamicCache):
-        return data.split(full_batch_size, split_size)
+        return data.batch_split(full_batch_size, split_size)
     elif isinstance(data, tuple):
         # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
         if isinstance(data[0], tuple):
@@ -3899,9 +3899,9 @@ def _concat(data):
         return None
     if isinstance(data[0], torch.Tensor):
         return torch.cat(data, dim=0)
-    # New efficient cache
+    # New cache format
     elif isinstance(data[0], DynamicCache):
-        return DynamicCache.from_splits(data)
+        return DynamicCache.from_batch_splits(data)
     elif isinstance(data[0], tuple):
         # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
         if isinstance(data[0][0], tuple):
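A short sketch of why the rename keeps `_split` and `_concat` symmetric: for tensors the pair is plain slicing plus `torch.cat`, and `batch_split`/`from_batch_splits` give `DynamicCache` the same lossless round trip. Toy values and variable names here are assumptions, not code from the commit:

import torch

# What _split does for a tensor: slice the batch dimension into chunks of split_size.
full_batch_size, split_size = 8, 4
batch = torch.arange(16).reshape(8, 2)
pieces = [batch[i : i + split_size] for i in range(0, full_batch_size, split_size)]

# What _concat does: glue the chunks back along dim 0. The round trip is lossless,
# which is why DynamicCache needs the matching batch_split/from_batch_splits pair.
assert torch.equal(torch.cat(pieces, dim=0), batch)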
