Add torch.compile Support For Mamba (huggingface#31247)
* modify mamba cache

* set up cache

* add test

* [run-slow] mamba

* [run-slow] mamba

* address comments

* [run-slow] mamba

* use_cache_position

* [run-slow] mamba

* [run-slow] mamba

* [run-slow] mamba

* [run-slow] mamba

* fix

* cache in generate

* [run-slow] mamba

* address comments

* [run-slow] mamba

* [run-slow] mamba

* address comments

* [run-slow] mamba

* fix

* [run-slow] mamba

* fix

* [run-slow] mamba

* fix cache name

* [run-slow] mamba
zhenglongjiepheonix authored and MHRDYN7 committed Jul 23, 2024
1 parent 7315305 commit cbe2713
Showing 4 changed files with 225 additions and 85 deletions.
74 changes: 74 additions & 0 deletions src/transformers/cache_utils.py
@@ -1249,3 +1249,77 @@ def reset(self):
            # In-place ops prevent breaking the static address
            self.key_cache[layer_idx].zero_()
            self.value_cache[layer_idx].zero_()


class MambaCache:
    """
    Cache for the Mamba model, which does not have an attention mechanism or key/value states.

    Arguments:
        config: MambaConfig
        max_batch_size: int
        dtype: torch.dtype
        device: torch.device

    Attributes:
        dtype: torch.dtype
        intermediate_size: int
        ssm_state_size: int
        conv_kernel_size: int
        conv_states: torch.Tensor of shape [num_hidden_layers, batch_size, intermediate_size, conv_kernel_size]
        ssm_states: torch.Tensor of shape [num_hidden_layers, batch_size, intermediate_size, ssm_state_size]
    """

    def __init__(
        self,
        config: PretrainedConfig,
        max_batch_size: int,
        dtype: torch.dtype = torch.float16,
        device: Optional[str] = None,
        **kwargs,
    ):
        self.dtype = dtype
        self.max_batch_size = max_batch_size
        self.intermediate_size = config.intermediate_size
        self.ssm_state_size = config.state_size
        self.conv_kernel_size = config.conv_kernel

        self.conv_states: torch.Tensor = torch.zeros(
            config.num_hidden_layers,
            self.max_batch_size,
            self.intermediate_size,
            self.conv_kernel_size,
            device=device,
            dtype=dtype,
        )
        self.ssm_states: torch.Tensor = torch.zeros(
            config.num_hidden_layers,
            self.max_batch_size,
            self.intermediate_size,
            self.ssm_state_size,
            device=device,
            dtype=dtype,
        )

        torch._dynamo.mark_static_address(self.conv_states)
        torch._dynamo.mark_static_address(self.ssm_states)

    def update_conv_state(
        self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
    ) -> torch.Tensor:
        conv_state = self.conv_states[layer_idx]
        cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)

        conv_state = conv_state.roll(shifts=-1, dims=-1)
        conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
        self.conv_states[layer_idx].zero_()
        self.conv_states[layer_idx] += conv_state
        return self.conv_states[layer_idx]

    def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
        self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device)
        return self.ssm_states[layer_idx]

    def reset(self):
        self.conv_states.zero_()
        self.ssm_states.zero_()
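
A minimal usage sketch of the new cache (illustration only, not part of the diff). It assumes MambaCache is importable from transformers.cache_utils as added above and that MambaConfig exposes the intermediate_size, state_size, conv_kernel and num_hidden_layers fields read in __init__; batch size, dtype and tensor values are illustrative.

import torch

from transformers import MambaConfig
from transformers.cache_utils import MambaCache

config = MambaConfig()  # default config; real use would take the checkpoint's config
cache = MambaCache(config, max_batch_size=2, dtype=torch.float32, device="cpu")

# One decoding step for layer 0: push the newest convolution input into the rolling
# window and overwrite the SSM state in place, keeping shapes and addresses static.
new_conv = torch.randn(2, config.intermediate_size, 1)
position = torch.tensor([config.conv_kernel - 1])  # clamped internally to the window
cache.update_conv_state(layer_idx=0, new_conv_state=new_conv, cache_position=position)

new_ssm = torch.randn(2, config.intermediate_size, config.state_size)
cache.update_ssm_state(layer_idx=0, new_ssm_state=new_ssm)

cache.reset()  # zeroes both state tensors without reallocating them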
29 changes: 20 additions & 9 deletions src/transformers/generation/utils.py
@@ -32,6 +32,7 @@
    EncoderDecoderCache,
    HQQQuantizedCache,
    HybridCache,
    MambaCache,
    QuantizedCacheConfig,
    QuantoQuantizedCache,
    SlidingWindowCache,
@@ -116,7 +117,12 @@
if is_accelerate_available():
    from accelerate.hooks import AlignDevicesHook, add_hook_to_module

NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache, "sliding_window": SlidingWindowCache, "hybrid": HybridCache}
NEED_SETUP_CACHE_CLASSES_MAPPING = {
    "static": StaticCache,
    "sliding_window": SlidingWindowCache,
    "hybrid": HybridCache,
    "mamba": MambaCache,
}
QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}


@@ -1431,8 +1437,9 @@ def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_l
            not hasattr(self, "_cache")
            or (not isinstance(cache_to_check, cache_cls))
            or cache_to_check.max_batch_size != max_batch_size
            or cache_to_check.max_cache_len < max_cache_len
        )
        if cache_implementation != "mamba":
            need_new_cache = need_new_cache or cache_to_check.max_cache_len < max_cache_len

        if requires_cross_attention_cache and hasattr(self, "_cache"):
            need_new_cache = (
@@ -1750,9 +1757,13 @@ def generate(
        )

        use_dynamic_cache_by_default = False
        if generation_config.cache_implementation is not None and model_kwargs.get("past_key_values") is not None:
        if "mamba" in self.__class__.__name__.lower():
            cache_name = "cache_params"
        else:
            cache_name = "past_key_values"
        if generation_config.cache_implementation is not None and (model_kwargs.get(cache_name) is not None):
            raise ValueError(
                "Passing both `cache_implementation` (used to initialize certain caches) and `past_key_values` (a "
                f"Passing both `cache_implementation` (used to initialize certain caches) and `{cache_name}` (a "
                "Cache object) is unsupported. Please use only one of the two."
            )
        elif generation_config.cache_implementation is not None:
@@ -1762,7 +1773,7 @@
                        "This model does not support `cache_implementation='static'`. Please check the following "
                        "issue: https://github.com/huggingface/transformers/issues/28981"
                    )
                model_kwargs["past_key_values"] = self._get_cache(
                model_kwargs[cache_name] = self._get_cache(
                    generation_config.cache_implementation,
                    getattr(generation_config, "num_beams", 1) * batch_size,
                    generation_config.max_length,
@@ -1793,23 +1804,23 @@
                        "Please install it via with `pip install hqq`"
                    )

                model_kwargs["past_key_values"] = cache_class(cache_config)
                model_kwargs[cache_name] = cache_class(cache_config)
        # Use DynamicCache() instance by default. This will avoid back and forth from legacy format that
        # keeps copying the cache thus using much more memory
        elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache():
            past = model_kwargs.get("past_key_values", None)
            past = model_kwargs.get(cache_name, None)
            requires_cross_attention_cache = (
                self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
            )
            if past is None:
                model_kwargs["past_key_values"] = (
                model_kwargs[cache_name] = (
                    DynamicCache()
                    if not requires_cross_attention_cache
                    else EncoderDecoderCache(DynamicCache(), DynamicCache())
                )
                use_dynamic_cache_by_default = True
            elif isinstance(past, tuple):
                model_kwargs["past_key_values"] = (
                model_kwargs[cache_name] = (
                    DynamicCache.from_legacy_cache(past)
                    if not requires_cross_attention_cache
                    else EncoderDecoderCache.from_legacy_cache(past)
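
An illustrative end-to-end sketch of what this commit enables (not part of the diff). It assumes a converted Mamba checkpoint such as state-spaces/mamba-130m-hf, and relies on generate() resolving cache_implementation="mamba" to MambaCache and handing it to the model under the cache_params kwarg, as wired up above; the compilation settings are illustrative.

import torch

from transformers import AutoTokenizer, MambaForCausalLM

model_id = "state-spaces/mamba-130m-hf"  # assumed checkpoint name
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = MambaForCausalLM.from_pretrained(model_id)

# The statically allocated MambaCache keeps tensor shapes and addresses fixed across
# decoding steps, which is what allows compiling the forward pass without recompilation.
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer("Mamba is a state space model", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20, cache_implementation="mamba")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))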