Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fully configurable activation checkpointing #951

Merged
merged 19 commits into from
Feb 8, 2024
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 95 additions & 21 deletions llmfoundry/models/mpt/modeling_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -908,32 +908,21 @@ def fsdp_wrap_fn(self, module: nn.Module) -> bool:

# Activation Checkpointing
def activation_checkpointing_fn(self, module: nn.Module) -> bool:
act_ckpt_list = getattr(self.config, 'activation_checkpointing_target',
None) or ['MPTBlock']
if isinstance(act_ckpt_list, str):
act_ckpt_list = [act_ckpt_list]
elif not isinstance(act_ckpt_list, list):
raise ValueError(
f'activation_checkpointing_target must be either a single string or a list, but got {type(act_ckpt_list)}'
if not hasattr(module, 'block_idx'):
log.debug(
f'No activating checkpointing for {module.__class__.__name__}, only transformer block or its submodules are eligible for activation checkpointing.'
cli99 marked this conversation as resolved.
Show resolved Hide resolved
)
return False

if 'MPTBlock' in act_ckpt_list or 'mptblock' in act_ckpt_list:
if len(act_ckpt_list) > 1:
log.info(
'Activation checkpointing MPTBlock only (ignoring other sub-block modules specified in activation_checkpointing_target).'
)
return isinstance(module, MPTBlock)

mod_types = ()
for mod_name in act_ckpt_list:
def get_act_ckpt_module(mod_name: str) -> nn.Module:
if mod_name.lower() == 'mptblock':
mod_types += (MPTBlock,)
mod_type = MPTBlock
elif mod_name in ATTN_CLASS_REGISTRY:
mod_types += (ATTN_CLASS_REGISTRY[mod_name],)
mod_type = ATTN_CLASS_REGISTRY[mod_name]
elif mod_name in FFN_CLASS_REGISTRY:
mod_types += (FFN_CLASS_REGISTRY[mod_name],)
mod_type = FFN_CLASS_REGISTRY[mod_name]
elif mod_name in NORM_CLASS_REGISTRY:
mod_types += (NORM_CLASS_REGISTRY[mod_name],)
mod_type = NORM_CLASS_REGISTRY[mod_name]
else:
msg = ', '.join(
list(ATTN_CLASS_REGISTRY.keys()) +
Expand All @@ -942,7 +931,92 @@ def activation_checkpointing_fn(self, module: nn.Module) -> bool:
raise ValueError(
f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available options {msg}.'
)
return isinstance(module, mod_types)
return mod_type

def get_target_block_list(target_blocks, max_block_idx: int) -> list:
    """Resolve a ``target_blocks`` spec into sorted, unique block indices.

    Args:
        target_blocks: One of
            * an ``int`` ``n``: selects blocks ``0 .. n-1`` (not clamped to
              ``max_block_idx``);
            * a ``list`` whose items are ints (taken as-is) or strings of
              the form ``"first-n"``, ``"last-m"`` or ``"middle-k"``;
            * a comma separated ``str`` of such strings (spaces ignored).
        max_block_idx: Index of the last block; string specs are clamped to
            the range ``0 .. max_block_idx``.

    Returns:
        The selected block indices, de-duplicated and in ascending order.

    Raises:
        ValueError: If ``target_blocks`` (or one of its elements) has an
            unsupported type or a malformed string spec.
    """

    def parse_ele_str(ele: str, max_block_idx: int) -> list:
        """Expand one ``"first-n"`` / ``"last-n"`` / ``"middle-n"`` spec."""
        position, dash, count_str = ele.partition('-')
        # Raise instead of asserting: asserts are stripped under ``-O`` and
        # every other invalid-spec path here already raises ValueError.
        if (not dash or position not in ('first', 'last', 'middle') or
                not count_str.isdigit()):
            raise ValueError(f'Invalid target_blocks element {ele}')

        num = int(count_str)
        num_blocks = max_block_idx + 1
        if position == 'first':
            return list(range(min(num, num_blocks)))
        if position == 'last':
            return list(range(max(num_blocks - num, 0), num_blocks))
        # 'middle': a window of ``num`` blocks starting ``num // 2`` below the
        # midpoint index, clamped to the valid block range.
        start = max(max_block_idx // 2 - num // 2, 0)
        end = min(start + num, num_blocks)
        return list(range(start, end))

    candidate_block_ids = []
    if isinstance(target_blocks, int):
        candidate_block_ids = list(range(target_blocks))
    elif isinstance(target_blocks, list):
        for ele in target_blocks:
            if isinstance(ele, int):
                candidate_block_ids.append(ele)
            elif isinstance(ele, str):
                candidate_block_ids.extend(parse_ele_str(ele, max_block_idx))
            else:
                raise ValueError(
                    f'target_blocks must be a list of integers or "first-n", "last-m" or "middle-k" where n, m, k are integers, but got {target_blocks}'
                )
    elif isinstance(target_blocks, str):
        for ele in target_blocks.replace(' ', '').split(','):
            candidate_block_ids.extend(parse_ele_str(ele, max_block_idx))
    else:
        raise ValueError(
            f'target_blocks must be either a single integer or a list of integers or a comma separated string made of "first-n", "last-m" or "middle-k" where n, m, k are integers, but got {type(target_blocks)}'
        )

    # Sort after de-duplicating so the result (and any log line built from
    # it) does not depend on set iteration order.
    return sorted(set(candidate_block_ids))

act_ckpt_target = getattr(self.config,
'activation_checkpointing_target', None)
act_ckpt_mod_to_blocks = {}
mvpatel2000 marked this conversation as resolved.
Show resolved Hide resolved
if act_ckpt_target is None:
mod = MPTBlock
act_ckpt_mod_to_blocks[mod] = -1
elif isinstance(act_ckpt_target, str):
mod = get_act_ckpt_module(target)
act_ckpt_mod_to_blocks[mod] = -1
elif isinstance(act_ckpt_target, list):
for target in act_ckpt_target:
mod = get_act_ckpt_module(target)
act_ckpt_mod_to_blocks[mod] = -1
elif isinstance(act_ckpt_target, dict):
for k, v in act_ckpt_target.items():
mod = get_act_ckpt_module(k)
block_ids = get_target_block_list(v, module.max_block_idx)
act_ckpt_mod_to_blocks[mod] = block_ids
log.info(
f'for module {mod.__name__}, target_blocks is set as {v}, activation checkpointing is applied to {block_ids} blocks.'
cli99 marked this conversation as resolved.
Show resolved Hide resolved
)
else:
raise ValueError(
f'activation_checkpointing_target must be either a single string or a list or a dict, but got {type(act_ckpt_target)}'
)

for k in act_ckpt_mod_to_blocks.keys():
if isinstance(module, k):
blocks = act_ckpt_mod_to_blocks[k]
return True if blocks == -1 else module.block_idx in blocks

return False

def prepare_inputs_for_generation(
self,
Expand Down
Loading