Speedup model init on CPU (by 10x+ for llama-3-8B as one example) #31771

Merged Jul 16, 2024 (42 commits)

Changes from all commits:
c3e49a8: 1,100%! (muellerzr, Jul 3, 2024)
e3bcff2: Clean (muellerzr, Jul 3, 2024)
248910a: Don't touch DS (muellerzr, Jul 9, 2024)
08df746: Experiment with dtype allocation (muellerzr, Jul 10, 2024)
f140836: skip test_load_save_without_tied_weights test (SunMarc, Jul 10, 2024)
b337348: A little faster (muellerzr, Jul 10, 2024)
9f45f62: Include proper upscaling? (muellerzr, Jul 10, 2024)
dce912e: Fixup tests (muellerzr, Jul 10, 2024)
f62d459: Potentially skip? (muellerzr, Jul 10, 2024)
7ebb3e9: Let's see if this fixes git history (muellerzr, Jul 11, 2024)
bef3a80: Maintain new dtype (muellerzr, Jul 11, 2024)
ca1010e: Fin (muellerzr, Jul 11, 2024)
989612f: Rm hook idea for now (muellerzr, Jul 11, 2024)
9fc7e8b: New approach, see what breaks (muellerzr, Jul 11, 2024)
79578ea: stage (muellerzr, Jul 11, 2024)
639df3b: Clean (muellerzr, Jul 11, 2024)
cab132b: Stash (muellerzr, Jul 12, 2024)
8338e2a: Should be fin now, just need to mark failing models (muellerzr, Jul 12, 2024)
67c52a0: Clean up (muellerzr, Jul 12, 2024)
2007249: Simplify (muellerzr, Jul 12, 2024)
6f2e650: Deal with weird models (muellerzr, Jul 12, 2024)
6cdae65: Enc/Dec (muellerzr, Jul 12, 2024)
35696f6: Skip w/ reason (muellerzr, Jul 12, 2024)
0ece40b: Adjust test (muellerzr, Jul 12, 2024)
6946f86: Fix test (muellerzr, Jul 12, 2024)
f3f751c: one more test (muellerzr, Jul 12, 2024)
a7c2a83: Keep experimenting (muellerzr, Jul 12, 2024)
178cb14: Fix ref (muellerzr, Jul 12, 2024)
48be6f8: TO REMOVE: testing feedback CI (muellerzr, Jul 15, 2024)
02c38fe: Right push (muellerzr, Jul 15, 2024)
74fdf4b: Update tests/utils/test_modeling_utils.py (muellerzr, Jul 15, 2024)
38d0e89: disable (muellerzr, Jul 15, 2024)
4335956: Add new func (muellerzr, Jul 15, 2024)
9c5dc50: Test nits from Amy (muellerzr, Jul 16, 2024)
c491952: Update src/transformers/modeling_utils.py (muellerzr, Jul 16, 2024)
fd3890a: Merge branch 'muellerzr-speedup-inference' of https://github.com/hugg… (muellerzr, Jul 16, 2024)
e8f4a14: Adjust comment (muellerzr, Jul 16, 2024)
512f34a: Adjust comment on skip (muellerzr, Jul 16, 2024)
ada401f: make private (muellerzr, Jul 16, 2024)
1e5466a: Fin (muellerzr, Jul 16, 2024)
70448cd: Should be a not flag (muellerzr, Jul 16, 2024)
21af73a: Clarify and rename test (muellerzr, Jul 16, 2024)

4 changes: 4 additions & 0 deletions docs/source/en/main_classes/model.md
@@ -40,6 +40,10 @@ for text generation, [`~generation.GenerationMixin`] (for the PyTorch models),
- push_to_hub
- all

Custom models should also include a `_supports_param_buffer_assignment` attribute, which determines whether the fast
(assignment-based) init can be applied to the particular model. A sign that your model needs this is if
`test_save_and_load_from_pretrained` fails; if so, set the attribute to `False`.
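
For example, a custom model that needs to opt out would declare the attribute on the class. A minimal, hypothetical sketch (`MyCustomModel` and its config are placeholders, not part of this PR):

```python
from transformers import PretrainedConfig, PreTrainedModel


class MyCustomModel(PreTrainedModel):
    config_class = PretrainedConfig

    # Opt out of fast (assignment-based) weight loading for this architecture.
    _supports_param_buffer_assignment = False
```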

## ModuleUtilsMixin

[[autodoc]] modeling_utils.ModuleUtilsMixin
64 changes: 58 additions & 6 deletions src/transformers/modeling_utils.py
@@ -338,6 +338,32 @@ def dtype_byte_size(dtype):
return bit_size // 8


def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
    """
    Checks if `model_to_load` supports param buffer assignment (such as when loading in empty weights)
    by first checking if the model explicitly disables it, then by ensuring that the incoming
    `state_dict` and the model's parameters share the same dtype.
    """
    if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
        return False

    # Some models explicitly do not support param buffer assignment
    if not getattr(model_to_load, "_supports_param_buffer_assignment", True):
        logger.debug(
            f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
        )
        return False

    # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
    first_key = list(model_to_load.state_dict().keys())[0]
    if start_prefix + first_key in state_dict:
        return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype

    # For cases when the `state_dict` doesn't contain real weights to the model (`test_model_weights_reload_no_missing_tied_weights`)
    return False
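
For background (an illustrative sketch, not part of this diff): "param buffer assignment" is the mechanism PyTorch exposes as `load_state_dict(..., assign=True)` (PyTorch >= 2.1) and, internally, as the `assign_to_params_buffers` entry of `local_metadata`. Instead of copying checkpoint tensors into pre-allocated parameter storage, the parameters are re-pointed at the checkpoint tensors, which is what removes the extra copy during init:

```python
import torch
from torch import nn

# Toy module and a matching state dict
model = nn.Linear(4, 4)
state_dict = {"weight": torch.randn(4, 4), "bias": torch.randn(4)}

# assign=True re-points the parameters at the checkpoint tensors instead of copying into them
model.load_state_dict(state_dict, assign=True)

# The parameter now shares storage with the state dict tensor (expected to print True on PyTorch >= 2.1)
print(model.weight.data_ptr() == state_dict["weight"].data_ptr())
```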


def shard_checkpoint(
state_dict: Dict[str, torch.Tensor], max_shard_size: Union[int, str] = "10GB", weights_name: str = WEIGHTS_NAME
):
@@ -657,7 +683,7 @@ def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]
return shared_tensors, identical


def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, assign_to_params_buffers=False):
# Convert old format to new format if needed from a PyTorch state_dict
old_keys = []
new_keys = []
@@ -685,8 +711,10 @@ def _load_state_dict_into_model(state_dict, start_prefix):

# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
# so we need to apply the function recursively.
def load(module: nn.Module, state_dict, prefix=""):
def load(module: nn.Module, state_dict, prefix="", assign_to_params_buffers=False):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
local_metadata["assign_to_params_buffers"] = assign_to_params_buffers

args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
# Parameters of module and children will start with prefix. We can exit early if there are none in this
# state_dict
@@ -710,9 +738,9 @@ def load(module: nn.Module, state_dict, prefix=""):

for name, child in module._modules.items():
if child is not None:
load(child, state_dict, prefix + name + ".")
load(child, state_dict, prefix + name + ".", assign_to_params_buffers)

load(model_to_load, state_dict, prefix=start_prefix)
load(model_to_load, state_dict, prefix=start_prefix, assign_to_params_buffers=assign_to_params_buffers)
# Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
# it's safe to delete it.
del state_dict
@@ -2852,6 +2880,10 @@ def from_pretrained(
The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
weights are discarded.

If the model weights are in the same precision as the base model (and the model architecture supports it), weights will
be lazily loaded using the `meta` device and brought into memory once an input is passed through that layer, regardless
of `low_cpu_mem_usage`.
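
As an illustration (a sketch, not part of the diff), the fast path is hit when the checkpoint dtype matches the dtype the model is loaded in, e.g. a bf16 checkpoint loaded in bf16:

```python
import torch

from transformers import AutoModelForCausalLM

# "meta-llama/Meta-Llama-3-8B" is only an example checkpoint; any checkpoint stored in the
# requested dtype can benefit from the assignment-based fast path.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    torch_dtype=torch.bfloat16,
)
```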

Parameters:
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
Can be either:
@@ -2952,7 +2984,13 @@

low_cpu_mem_usage(`bool`, *optional*):
Tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Generally should be combined with a `device_map` (such as `"auto"`) for best results.
This is an experimental feature and subject to change at any moment.
<Tip>
If the model weights are in the same precision as the model being loaded, `low_cpu_mem_usage` (without
`device_map`) is redundant and will not provide any benefit with regard to CPU memory usage. However,
it should still be enabled if you are passing in a `device_map`.
</Tip>
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model under a specific `dtype`. The different options
are:
@@ -4018,6 +4056,7 @@ def _fix_key(key):

missing_keys = sorted(set(expected_keys) - set(loaded_keys))
unexpected_keys = set(loaded_keys) - set(expected_keys)

# Remove nonpersistent buffers from unexpected keys: they are not in the state dict but will be in the model
# buffers
model_buffers = {n for n, _ in model.named_buffers()}
@@ -4252,7 +4291,12 @@ def _find_mismatched_keys(
)
else:
# Sharded checkpoint or whole but low_cpu_mem_usage==True
error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
assign_to_params_buffers = check_support_param_buffer_assignment(
model_to_load, state_dict, start_prefix
)
error_msgs = _load_state_dict_into_model(
model_to_load, state_dict, start_prefix, assign_to_params_buffers
)

else:
# This should always be a list but, just to be sure.
@@ -4280,6 +4324,7 @@

if len(resolved_archive_file) > 1:
resolved_archive_file = logging.tqdm(resolved_archive_file, desc="Loading checkpoint shards")
assign_to_params_buffers = None
for shard_file in resolved_archive_file:
# Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload.
if shard_file in disk_only_shard_files:
@@ -4323,7 +4368,14 @@ def _find_mismatched_keys(
)
error_msgs += new_error_msgs
else:
error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
# Sharded checkpoint or whole but low_cpu_mem_usage==True
if assign_to_params_buffers is None:
assign_to_params_buffers = check_support_param_buffer_assignment(
model_to_load, state_dict, start_prefix
)
error_msgs += _load_state_dict_into_model(
model_to_load, state_dict, start_prefix, assign_to_params_buffers
)

# force memory release
del state_dict
src/transformers/models/encoder_decoder/modeling_encoder_decoder.py
@@ -178,6 +178,7 @@ class EncoderDecoderModel(PreTrainedModel):
base_model_prefix = "encoder_decoder"
main_input_name = "input_ids"
supports_gradient_checkpointing = True
_supports_param_buffer_assignment = False

def __init__(
self,
1 change: 1 addition & 0 deletions src/transformers/models/lxmert/modeling_lxmert.py
@@ -773,6 +773,7 @@ class LxmertPreTrainedModel(PreTrainedModel):
config_class = LxmertConfig
load_tf_weights = load_tf_weights_in_lxmert
base_model_prefix = "lxmert"
_supports_param_buffer_assignment = False

def _init_weights(self, module):
"""Initialize the weights"""
src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py
@@ -159,6 +159,7 @@ class VisionEncoderDecoderModel(PreTrainedModel):
base_model_prefix = "vision_encoder_decoder"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_supports_param_buffer_assignment = False

def __init__(
self,
6 changes: 6 additions & 0 deletions tests/models/bart/test_modeling_bart.py
@@ -512,6 +512,12 @@ def test_generate_fp16(self):
model.generate(input_ids, attention_mask=attention_mask)
model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)

@unittest.skip(
reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
)
def test_load_save_without_tied_weights(self):
pass


def assert_tensors_close(a, b, atol=1e-12, prefix=""):
"""If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""
6 changes: 6 additions & 0 deletions tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py
@@ -476,6 +476,12 @@ def test_for_change_to_full_attn(self):

self.assertTrue(torch.allclose(outputs1, outputs2, atol=1e-5))

@unittest.skip(
reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
)
def test_load_save_without_tied_weights(self):
pass


@require_torch
@require_sentencepiece
12 changes: 12 additions & 0 deletions tests/models/longt5/test_modeling_longt5.py
@@ -758,6 +758,12 @@ def _check_encoder_attention_for_generate(self, attentions, batch_size, config,
[encoder_expected_shape] * len(attentions),
)

@unittest.skip(
reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
)
def test_load_save_without_tied_weights(self):
pass


@require_torch
class LongT5TGlobalModelTest(LongT5ModelTest):
@@ -1097,6 +1103,12 @@ def test_attention_outputs(self):
[self.model_tester.num_attention_heads, block_len, 3 * block_len],
)

@unittest.skip(
reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
)
def test_load_save_without_tied_weights(self):
pass


class LongT5EncoderOnlyTGlobalModelTest(LongT5EncoderOnlyModelTest):
def setUp(self):
6 changes: 6 additions & 0 deletions tests/models/lxmert/test_modeling_lxmert.py
@@ -778,6 +778,12 @@ def test_save_load_low_cpu_mem_usage_checkpoints(self):
def test_save_load_low_cpu_mem_usage_no_safetensors(self):
pass

@unittest.skip(
reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
)
def test_load_save_without_tied_weights(self):
pass


@require_torch
class LxmertModelIntegrationTest(unittest.TestCase):
6 changes: 6 additions & 0 deletions tests/models/m2m_100/test_modeling_m2m_100.py
@@ -331,6 +331,12 @@ def test_generate_fp16(self):
model.generate(input_ids, attention_mask=attention_mask)
model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)

@unittest.skip(
reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
)
def test_load_save_without_tied_weights(self):
pass


def _long_tensor(tok_lst):
return torch.tensor(tok_lst, dtype=torch.long, device=torch_device)
6 changes: 6 additions & 0 deletions tests/models/mbart/test_modeling_mbart.py
@@ -369,6 +369,12 @@ def test_ensure_weights_are_shared(self):
2,
)

@unittest.skip(
reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
)
def test_load_save_without_tied_weights(self):
pass


def assert_tensors_close(a, b, atol=1e-12, prefix=""):
"""If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""
6 changes: 6 additions & 0 deletions tests/models/nllb_moe/test_modeling_nllb_moe.py
@@ -346,6 +346,12 @@ def test_get_loss(self):
self.assertIsNotNone(model(**input_dict)["encoder_router_logits"][1])
self.assertIsNotNone(model(**input_dict)["decoder_router_logits"][0])

@unittest.skip(
reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
)
def test_load_save_without_tied_weights(self):
pass


@require_torch
@require_sentencepiece
6 changes: 6 additions & 0 deletions tests/models/plbart/test_modeling_plbart.py
@@ -323,6 +323,12 @@ def test_generate_fp16(self):
def test_sample_generate(self):
pass

@unittest.skip(
reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
)
def test_load_save_without_tied_weights(self):
pass


def assert_tensors_close(a, b, atol=1e-12, prefix=""):
"""If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""
12 changes: 12 additions & 0 deletions tests/models/seamless_m4t/test_modeling_seamless_m4t.py
@@ -506,6 +506,12 @@ def test_training_gradient_checkpointing_use_reentrant(self):
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass

@unittest.skip(
reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
)
def test_load_save_without_tied_weights(self):
pass

def test_attention_outputs(self):
# expected length is subsampled so need to change a bit this test
if not self.has_attentions:
@@ -758,6 +764,12 @@ def test_training_gradient_checkpointing_use_reentrant_false(self):
def test_retain_grad_hidden_states_attentions(self):
pass

@unittest.skip(
reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
)
def test_load_save_without_tied_weights(self):
pass


@require_torch
class SeamlessM4TGenerationTest(unittest.TestCase):
12 changes: 12 additions & 0 deletions tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py
@@ -522,6 +522,12 @@ def test_training_gradient_checkpointing_use_reentrant(self):
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass

@unittest.skip(
reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
)
def test_load_save_without_tied_weights(self):
pass

def test_attention_outputs(self):
# expected length is subsampled so need to change a bit this test
if not self.has_attentions:
@@ -748,6 +754,12 @@ def test_training_gradient_checkpointing_use_reentrant(self):
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass

@unittest.skip(
reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
)
def test_load_save_without_tied_weights(self):
pass


@require_torch
class SeamlessM4Tv2GenerationTest(unittest.TestCase):
tests/models/switch_transformers/test_modeling_switch_transformers.py
@@ -720,6 +720,12 @@ def test_generate_with_head_masking(self):
attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0)

@unittest.skip(
reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
)
def test_load_save_without_tied_weights(self):
pass


class SwitchTransformersEncoderOnlyModelTester:
def __init__(
@@ -843,6 +849,12 @@ def test_model_fp16_forward(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs)

@unittest.skip(
reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
)
def test_load_save_without_tied_weights(self):
pass


def use_task_specific_params(model, task):
model.config.update(model.config.task_specific_params[task])