add act checkpoint at sub layer level #720

Merged (15 commits) on Nov 13, 2023
24 changes: 22 additions & 2 deletions in llmfoundry/models/mpt/modeling_mpt.py

@@ -45,7 +45,9 @@
 from transformers.models.llama.modeling_llama import \
     LlamaRotaryEmbedding as HFRotaryEmbedding
 
-from llmfoundry.models.layers.attention import attn_bias_shape, build_attn_bias
+from llmfoundry.models.layers.attention import (ATTN_CLASS_REGISTRY,
+                                                attn_bias_shape,
+                                                build_attn_bias)
 from llmfoundry.models.layers.blocks import MPTBlock
 from llmfoundry.models.layers.custom_embedding import SharedEmbedding
 from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY as FC_CLASS_REGISTRY
@@ -705,7 +707,25 @@ def fsdp_wrap_fn(self, module: nn.Module) -> bool:
 
     # Activation Checkpointing
     def activation_checkpointing_fn(self, module: nn.Module) -> bool:
-        return isinstance(module, MPTBlock)
+        if not hasattr(self.config, 'activation_checkpointing_target'):
+            return isinstance(module, MPTBlock)
+        act_ckpt_str = self.config.activation_checkpointing_target
+        act_ckpt_lst = act_ckpt_str.replace(' ', '').split(',')
+        if 'MPTBlock' in act_ckpt_lst or 'mptblock' in act_ckpt_lst:
+            act_ckpt_lst = ['MPTBlock']
+        mod_types = ()
+        for mod_name in act_ckpt_lst:
+            if mod_name.lower() == 'mptblock':
+                mod_types += (MPTBlock,)
+            elif mod_name in ATTN_CLASS_REGISTRY:
+                mod_types += (ATTN_CLASS_REGISTRY[mod_name],)
+            elif mod_name in FFN_CLASS_REGISTRY:
+                mod_types += (FFN_CLASS_REGISTRY[mod_name],)
+            elif mod_name in NORM_CLASS_REGISTRY:
+                mod_types += (NORM_CLASS_REGISTRY[mod_name],)
+            else:
+                continue
+        return isinstance(module, mod_types)
 
     def prepare_inputs_for_generation(
         self,
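
In effect, activation_checkpointing_target lets the config name specific sub-layer classes (resolved through the attention, FFN, and norm registries) to checkpoint, instead of always checkpointing whole MPTBlocks. Below is a minimal usage sketch, assuming the field can be passed as an extra MPTConfig kwarg and that 'multihead_attention' and 'low_precision_layernorm' are valid registry keys; treat those names, and the direct call to the predicate, as illustrative assumptions rather than confirmed API:

    # Hedged sketch: invoke the new predicate directly to see which
    # sub-modules would be wrapped. In a real run, Composer/FSDP calls
    # activation_checkpointing_fn for each sub-module during wrapping.
    from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM

    config = MPTConfig(
        d_model=128,
        n_heads=4,
        n_layers=2,
        # Comma-separated registry names; the parser above strips spaces.
        # These key names are assumptions about the registries' contents.
        activation_checkpointing_target='multihead_attention,low_precision_layernorm',
    )
    model = MPTForCausalLM(config)

    for name, module in model.named_modules():
        if model.activation_checkpointing_fn(module):
            print('would checkpoint:', name)

In a training run the same value would be set under the model config (e.g. in the run YAML) rather than constructed by hand; the loop above only serves to show that matching is by module type, so every instance of a named class gets checkpointed.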