Add torch.compile Support For Mamba #31247

Merged
Changes from all commits
41 commits
3f9eeb8
modify mamba cache
zhenglongjiepheonix Jun 4, 2024
d6413fe
set up cache
zhenglongjiepheonix Jun 5, 2024
054d7cf
add test
zhenglongjiepheonix Jun 6, 2024
93fa82f
fix conflict
zhenglongjiepheonix Jun 6, 2024
48e0ff0
[run-slow] mamba
zhenglongjiepheonix Jun 6, 2024
1e6f0e9
[run-slow] mamba
zhenglongjiepheonix Jun 6, 2024
23d383a
Merge remote-tracking branch 'upstream/main' into longjie/add_torch_c…
zhenglongjiepheonix Jun 7, 2024
49ee4cb
address comments
zhenglongjiepheonix Jun 9, 2024
bc85aa9
Merge remote-tracking branch 'upstream/main' into longjie/add_torch_c…
zhenglongjiepheonix Jun 9, 2024
247e789
[run-slow] mamba
zhenglongjiepheonix Jun 9, 2024
9e1fb0e
use_cache_position
zhenglongjiepheonix Jun 10, 2024
8a132ac
Merge remote-tracking branch 'upstream/main' into longjie/add_torch_c…
zhenglongjiepheonix Jun 10, 2024
100b999
[run-slow] mamba
zhenglongjiepheonix Jun 10, 2024
2dc8986
[run-slow] mamba
zhenglongjiepheonix Jun 10, 2024
9732a62
Merge remote-tracking branch 'upstream/main' into longjie/add_torch_c…
zhenglongjiepheonix Jun 11, 2024
c2c8e5d
[run-slow] mamba
zhenglongjiepheonix Jun 11, 2024
9120fbf
Merge remote-tracking branch 'upstream/main' into longjie/add_torch_c…
zhenglongjiepheonix Jun 12, 2024
63568df
[run-slow] mamba
zhenglongjiepheonix Jun 12, 2024
0736da8
fix
zhenglongjiepheonix Jun 17, 2024
3141d26
Merge remote-tracking branch 'upstream/main' into longjie/add_torch_c…
zhenglongjiepheonix Jun 17, 2024
584528e
Merge remote-tracking branch 'upstream/main' into longjie/add_torch_c…
zhenglongjiepheonix Jun 20, 2024
f136d3c
cache in generate
zhenglongjiepheonix Jun 20, 2024
ec4a7a3
Merge remote-tracking branch 'upstream/main' into longjie/add_torch_c…
zhenglongjiepheonix Jun 20, 2024
400feec
[run-slow] mamba
zhenglongjiepheonix Jun 20, 2024
de9182d
address comments
zhenglongjiepheonix Jun 27, 2024
38441cd
resolve conflict
zhenglongjiepheonix Jun 27, 2024
a254a09
[run-slow] mamba
zhenglongjiepheonix Jun 27, 2024
ac456ed
[run-slow] mamba
zhenglongjiepheonix Jun 27, 2024
3e95813
Merge remote-tracking branch 'upstream/main' into longjie/add_torch_c…
zhenglongjiepheonix Jul 1, 2024
801e8d1
address comments
zhenglongjiepheonix Jul 1, 2024
b8af2a3
[run-slow] mamba
zhenglongjiepheonix Jul 1, 2024
b2e8a0b
fix
zhenglongjiepheonix Jul 8, 2024
708d302
fix
zhenglongjiepheonix Jul 8, 2024
97b1add
[run-slow] mamba
zhenglongjiepheonix Jul 8, 2024
2feeeb0
fix
zhenglongjiepheonix Jul 8, 2024
b045c38
fix conflict
zhenglongjiepheonix Jul 13, 2024
fcdc98f
[run-slow] mamba
zhenglongjiepheonix Jul 13, 2024
82b0c9b
fix cache name
zhenglongjiepheonix Jul 16, 2024
611e0d7
Merge remote-tracking branch 'upstream/main' into longjie/add_torch_c…
zhenglongjiepheonix Jul 16, 2024
bc1563e
Merge remote-tracking branch 'upstream/main' into longjie/add_torch_c…
zhenglongjiepheonix Jul 16, 2024
78c8c1c
[run-slow] mamba
zhenglongjiepheonix Jul 16, 2024
74 changes: 74 additions & 0 deletions src/transformers/cache_utils.py
@@ -1249,3 +1249,77 @@ def reset(self):
# In-place ops prevent breaking the static address
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()


class MambaCache:
"""
Cache for the Mamba model, which has no attention mechanism and therefore no key/value states.

Arguments:
config: MambaConfig
max_batch_size: int
dtype: torch.dtype
device: torch.device

Attributes:
dtype: torch.dtype
intermediate_size: int
ssm_state_size: int
conv_kernel_size: int
conv_states: torch.Tensor [layer_idx, batch_size, intermediate_size, conv_kernel_size]
ssm_states: torch.Tensor [layer_idx, batch_size, intermediate_size, ssm_state_size]
"""

def __init__(
self,
config: PretrainedConfig,
max_batch_size: int,
dtype: torch.dtype = torch.float16,
device: Optional[str] = None,
**kwargs,
):
self.dtype = dtype
self.max_batch_size = max_batch_size
self.intermediate_size = config.intermediate_size
self.ssm_state_size = config.state_size
self.conv_kernel_size = config.conv_kernel

self.conv_states: torch.Tensor = torch.zeros(
config.num_hidden_layers,
self.max_batch_size,
self.intermediate_size,
self.conv_kernel_size,
device=device,
dtype=dtype,
)
self.ssm_states: torch.Tensor = torch.zeros(
config.num_hidden_layers,
self.max_batch_size,
self.intermediate_size,
self.ssm_state_size,
device=device,
dtype=dtype,
)

torch._dynamo.mark_static_address(self.conv_states)
torch._dynamo.mark_static_address(self.ssm_states)

def update_conv_state(
self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
) -> torch.Tensor:
conv_state = self.conv_states[layer_idx]
cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)

conv_state = conv_state.roll(shifts=-1, dims=-1)
conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
self.conv_states[layer_idx].zero_()
self.conv_states[layer_idx] += conv_state
return self.conv_states[layer_idx]

def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device)
return self.ssm_states[layer_idx]

def reset(self):
self.conv_states.zero_()
self.ssm_states.zero_()
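
For reference, the sketch below shows how this cache might be exercised on its own, assuming a build that includes this PR. The batch size, device, and dummy tensors are illustrative only; the shapes are read from the default `MambaConfig`.

import torch

from transformers import MambaConfig
from transformers.cache_utils import MambaCache  # location added by this PR

config = MambaConfig()  # default config; all sizes below are read from it
cache = MambaCache(config, max_batch_size=2, dtype=torch.float32, device="cpu")

# One decode step for layer 0: write a single new column into the rolling conv
# window (cache_position is clamped to the window length) and overwrite the SSM state.
new_conv_column = torch.randn(2, config.intermediate_size, 1)
cache.update_conv_state(layer_idx=0, new_conv_state=new_conv_column, cache_position=torch.tensor([10]))

new_ssm_state = torch.randn(2, config.intermediate_size, config.state_size)
cache.update_ssm_state(layer_idx=0, new_ssm_state=new_ssm_state)

# Both buffers are zeroed in place rather than reallocated.
cache.reset()

Because the update methods only mutate the pre-allocated tensors, a compiled decoding step sees the same tensor addresses on every call, which is what `torch._dynamo.mark_static_address` relies on.
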
29 changes: 20 additions & 9 deletions src/transformers/generation/utils.py
@@ -32,6 +32,7 @@
EncoderDecoderCache,
HQQQuantizedCache,
HybridCache,
MambaCache,
QuantizedCacheConfig,
QuantoQuantizedCache,
SlidingWindowCache,
@@ -116,7 +117,12 @@
if is_accelerate_available():
from accelerate.hooks import AlignDevicesHook, add_hook_to_module

NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache, "sliding_window": SlidingWindowCache, "hybrid": HybridCache}
NEED_SETUP_CACHE_CLASSES_MAPPING = {
"static": StaticCache,
"sliding_window": SlidingWindowCache,
"hybrid": HybridCache,
"mamba": MambaCache,
}
QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}


@@ -1431,8 +1437,9 @@ def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_l
not hasattr(self, "_cache")
or (not isinstance(cache_to_check, cache_cls))
or cache_to_check.max_batch_size != max_batch_size
or cache_to_check.max_cache_len < max_cache_len
)
if cache_implementation != "mamba":
need_new_cache = need_new_cache or cache_to_check.max_cache_len < max_cache_len

if requires_cross_attention_cache and hasattr(self, "_cache"):
need_new_cache = (
@@ -1750,9 +1757,13 @@ def generate(
)

use_dynamic_cache_by_default = False
if generation_config.cache_implementation is not None and model_kwargs.get("past_key_values") is not None:
if "mamba" in self.__class__.__name__.lower():
cache_name = "cache_params"
else:
cache_name = "past_key_values"
if generation_config.cache_implementation is not None and (model_kwargs.get(cache_name) is not None):
raise ValueError(
"Passing both `cache_implementation` (used to initialize certain caches) and `past_key_values` (a "
f"Passing both `cache_implementation` (used to initialize certain caches) and `{cache_name}` (a "
"Cache object) is unsupported. Please use only one of the two."
)
elif generation_config.cache_implementation is not None:
@@ -1762,7 +1773,7 @@
"This model does not support `cache_implementation='static'`. Please check the following "
"issue: https://github.com/huggingface/transformers/issues/28981"
)
model_kwargs["past_key_values"] = self._get_cache(
model_kwargs[cache_name] = self._get_cache(
generation_config.cache_implementation,
getattr(generation_config, "num_beams", 1) * batch_size,
generation_config.max_length,
@@ -1793,23 +1804,23 @@
"Please install it via with `pip install hqq`"
)

model_kwargs["past_key_values"] = cache_class(cache_config)
model_kwargs[cache_name] = cache_class(cache_config)
# Use DynamicCache() instance by default. This will avoid back and forth from legacy format that
# keeps copying the cache thus using much more memory
elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache():
past = model_kwargs.get("past_key_values", None)
past = model_kwargs.get(cache_name, None)
requires_cross_attention_cache = (
self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
)
if past is None:
model_kwargs["past_key_values"] = (
model_kwargs[cache_name] = (
DynamicCache()
if not requires_cross_attention_cache
else EncoderDecoderCache(DynamicCache(), DynamicCache())
)
use_dynamic_cache_by_default = True
elif isinstance(past, tuple):
model_kwargs["past_key_values"] = (
model_kwargs[cache_name] = (
DynamicCache.from_legacy_cache(past)
if not requires_cross_attention_cache
else EncoderDecoderCache.from_legacy_cache(past)
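
Taken together, these changes let `generate` allocate a MambaCache via `cache_implementation="mamba"` and route it through the model-specific `cache_params` argument instead of `past_key_values`. A rough end-to-end sketch of the intended usage, assuming a build with this PR; the checkpoint name, dtype, and `torch.compile` options are illustrative choices, not requirements of the change:

import torch
from transformers import AutoTokenizer, MambaForCausalLM

tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf", torch_dtype=torch.float16).to("cuda")

# Compiling the decoding forward pass is what this PR targets: the statically
# allocated MambaCache keeps tensor addresses stable, so the compiled graph can
# be reused across steps instead of triggering recompilation.
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer("The capital of France is", return_tensors="pt").to("cuda")
output = model.generate(**inputs, max_new_tokens=20, cache_implementation="mamba")
print(tokenizer.decode(output[0], skip_special_tokens=True))
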