Speedup model init on CPU (by 10x+ for llama-3-8B as one example) #31771

Merged: 42 commits, merged on Jul 16, 2024

Commits (42; the diff below shows changes from 33 of them)
c3e49a8  1,100%! (muellerzr, Jul 3, 2024)
e3bcff2  Clean (muellerzr, Jul 3, 2024)
248910a  Don't touch DS (muellerzr, Jul 9, 2024)
08df746  Experiment with dtype allocation (muellerzr, Jul 10, 2024)
f140836  skip test_load_save_without_tied_weights test (SunMarc, Jul 10, 2024)
b337348  A little faster (muellerzr, Jul 10, 2024)
9f45f62  Include proper upscaling? (muellerzr, Jul 10, 2024)
dce912e  Fixup tests (muellerzr, Jul 10, 2024)
f62d459  Potentially skip? (muellerzr, Jul 10, 2024)
7ebb3e9  Let's see if this fixes git history (muellerzr, Jul 11, 2024)
bef3a80  Maintain new dtype (muellerzr, Jul 11, 2024)
ca1010e  Fin (muellerzr, Jul 11, 2024)
989612f  Rm hook idea for now (muellerzr, Jul 11, 2024)
9fc7e8b  New approach, see what breaks (muellerzr, Jul 11, 2024)
79578ea  stage (muellerzr, Jul 11, 2024)
639df3b  Clean (muellerzr, Jul 11, 2024)
cab132b  Stash (muellerzr, Jul 12, 2024)
8338e2a  Should be fin now, just need to mark failing models (muellerzr, Jul 12, 2024)
67c52a0  Clean up (muellerzr, Jul 12, 2024)
2007249  Simplify (muellerzr, Jul 12, 2024)
6f2e650  Deal with weird models (muellerzr, Jul 12, 2024)
6cdae65  Enc/Dec (muellerzr, Jul 12, 2024)
35696f6  Skip w/ reason (muellerzr, Jul 12, 2024)
0ece40b  Adjust test (muellerzr, Jul 12, 2024)
6946f86  Fix test (muellerzr, Jul 12, 2024)
f3f751c  one more test (muellerzr, Jul 12, 2024)
a7c2a83  Keep experimenting (muellerzr, Jul 12, 2024)
178cb14  Fix ref (muellerzr, Jul 12, 2024)
48be6f8  TO REMOVE: testing feedback CI (muellerzr, Jul 15, 2024)
02c38fe  Right push (muellerzr, Jul 15, 2024)
74fdf4b  Update tests/utils/test_modeling_utils.py (muellerzr, Jul 15, 2024)
38d0e89  disable (muellerzr, Jul 15, 2024)
4335956  Add new func (muellerzr, Jul 15, 2024)
9c5dc50  Test nits from Amy (muellerzr, Jul 16, 2024)
c491952  Update src/transformers/modeling_utils.py (muellerzr, Jul 16, 2024)
fd3890a  Merge branch 'muellerzr-speedup-inference' of https://github.com/hugg… (muellerzr, Jul 16, 2024)
e8f4a14  Adjust comment (muellerzr, Jul 16, 2024)
512f34a  Adjust comment on skip (muellerzr, Jul 16, 2024)
ada401f  make private (muellerzr, Jul 16, 2024)
1e5466a  Fin (muellerzr, Jul 16, 2024)
70448cd  Should be a not flag (muellerzr, Jul 16, 2024)
21af73a  Clarify and rename test (muellerzr, Jul 16, 2024)
53 changes: 47 additions & 6 deletions src/transformers/modeling_utils.py
@@ -338,6 +338,31 @@ def dtype_byte_size(dtype):
return bit_size // 8


def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
"""
Checks if `model_to_load` supports param buffer assignment (such
as when loading in empty weights) by first checking
if the model explicitly disables it, then by ensuring that the state dict keys
are a subset of the model's parameters.
"""
if len([key for key in state_dict if key.startswith(start_prefix)]) > 0:
# Some models explicitly do not support param buffer assignment
if hasattr(model_to_load, "supports_param_buffer_assignment"):
logger.debug(
f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
)
return False
else:
# If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
first_key = list(model_to_load.state_dict().keys())[0]
if start_prefix + first_key in state_dict:
return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype
else:
# For cases when the `state_dict` doesn't have any real weights (`albert`)
Review comment (Collaborator): This is weird - what do we mean by "real weights" here? If I look at the safetensors file for a checkpoint it looks like there are weights.

Reply (muellerzr, Contributor Author): It's in reference to test_model_weights_reload_no_missing_tied_weights, in which case we have nuked the saved tensors and as a result they don't exist in the state dict at all. In this case we should return False. I've changed it to reference the specific test. (In most real-world cases, we shouldn't get to this point.)

return False
return False


def shard_checkpoint(
state_dict: Dict[str, torch.Tensor], max_shard_size: Union[int, str] = "10GB", weights_name: str = WEIGHTS_NAME
):
@@ -657,7 +682,7 @@ def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]
return shared_tensors, identical


def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, assign_to_params_buffers=False):
# Convert old format to new format if needed from a PyTorch state_dict
old_keys = []
new_keys = []
@@ -685,8 +710,10 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):

# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
# so we need to apply the function recursively.
def load(module: nn.Module, state_dict, prefix=""):
def load(module: nn.Module, state_dict, prefix="", assign_to_params_buffers=False):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
local_metadata["assign_to_params_buffers"] = assign_to_params_buffers

args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
# Parameters of module and children will start with prefix. We can exit early if there are none in this
# state_dict
@@ -710,9 +737,9 @@ def load(module: nn.Module, state_dict, prefix=""):

for name, child in module._modules.items():
if child is not None:
load(child, state_dict, prefix + name + ".")
load(child, state_dict, prefix + name + ".", assign_to_params_buffers)

load(model_to_load, state_dict, prefix=start_prefix)
load(model_to_load, state_dict, prefix=start_prefix, assign_to_params_buffers=assign_to_params_buffers)
# Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
# it's safe to delete it.
del state_dict
@@ -4018,6 +4045,7 @@ def _fix_key(key):

missing_keys = sorted(set(expected_keys) - set(loaded_keys))
unexpected_keys = set(loaded_keys) - set(expected_keys)

# Remove nonpersistent buffers from unexpected keys: they are not in the state dict but will be in the model
# buffers
model_buffers = {n for n, _ in model.named_buffers()}
@@ -4252,7 +4280,12 @@ def _find_mismatched_keys(
)
else:
# Sharded checkpoint or whole but low_cpu_mem_usage==True
error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
assign_to_params_buffers = check_support_param_buffer_assignment(
model_to_load, state_dict, start_prefix
)
error_msgs = _load_state_dict_into_model(
model_to_load, state_dict, start_prefix, assign_to_params_buffers
)

else:
# This should always be a list but, just to be sure.
@@ -4280,6 +4313,7 @@ def _find_mismatched_keys(

if len(resolved_archive_file) > 1:
resolved_archive_file = logging.tqdm(resolved_archive_file, desc="Loading checkpoint shards")
assign_to_params_buffers = None
for shard_file in resolved_archive_file:
# Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload.
if shard_file in disk_only_shard_files:
@@ -4323,7 +4357,14 @@ def _find_mismatched_keys(
)
error_msgs += new_error_msgs
else:
error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
# Sharded checkpoint or whole but low_cpu_mem_usage==True
if assign_to_params_buffers is None:
assign_to_params_buffers = check_support_param_buffer_assignment(
model_to_load, state_dict, start_prefix
)
error_msgs += _load_state_dict_into_model(
model_to_load, state_dict, start_prefix, assign_to_params_buffers
)

# force memory release
del state_dict
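As context for the new `assign_to_params_buffers` plumbing above, here is a minimal standalone sketch (not part of the diff) of what assignment-based loading means in plain PyTorch. The dtype comparison mirrors the gate in `check_support_param_buffer_assignment`; `load_state_dict(..., assign=True)` (available in recent PyTorch, 2.1+ as far as I know) is the public counterpart of the `assign_to_params_buffers` metadata flag set in `_load_state_dict_into_model`.

```python
# Illustrative sketch only (not from the PR): the dtype gate in miniature,
# and how assignment differs from the default copy-based load.
import torch
import torch.nn as nn

model = nn.Linear(8, 8)  # parameters are float32 by default
fp16_checkpoint = {k: v.clone().half() for k, v in model.state_dict().items()}

# The dtype comparison used by `check_support_param_buffer_assignment`:
first_key = next(iter(model.state_dict()))
can_assign = fp16_checkpoint[first_key].dtype == model.state_dict()[first_key].dtype
print(can_assign)  # False: fp16 checkpoint vs fp32 model, so fall back to the copy path

# When dtypes do match, the checkpoint tensors can be attached directly,
# skipping one full copy of the weights (requires PyTorch >= 2.1):
fp32_checkpoint = {k: v.clone() for k, v in model.state_dict().items()}
model.load_state_dict(fp32_checkpoint, assign=True)
print(model.weight.data_ptr() == fp32_checkpoint["weight"].data_ptr())  # True: same storage
```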
src/transformers/models/encoder_decoder/modeling_encoder_decoder.py
@@ -178,6 +178,7 @@ class EncoderDecoderModel(PreTrainedModel):
base_model_prefix = "encoder_decoder"
main_input_name = "input_ids"
supports_gradient_checkpointing = True
supports_param_buffer_assignment = False

def __init__(
self,
1 change: 1 addition & 0 deletions src/transformers/models/lxmert/modeling_lxmert.py
@@ -773,6 +773,7 @@ class LxmertPreTrainedModel(PreTrainedModel):
config_class = LxmertConfig
load_tf_weights = load_tf_weights_in_lxmert
base_model_prefix = "lxmert"
supports_param_buffer_assignment = False

def _init_weights(self, module):
"""Initialize the weights"""
src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py
@@ -159,6 +159,7 @@ class VisionEncoderDecoderModel(PreTrainedModel):
base_model_prefix = "vision_encoder_decoder"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
supports_param_buffer_assignment = False

def __init__(
self,
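The three model files above opt out by declaring a class attribute. A hedged sketch of the same pattern on a toy architecture follows; `TiedToyConfig` and `TiedToyModel` are made-up names, and only the `supports_param_buffer_assignment` attribute itself comes from this PR.

```python
# Hedged sketch: a skeleton PreTrainedModel subclass opting out of the
# assignment fast path because its weights are tied.
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel


class TiedToyConfig(PretrainedConfig):
    model_type = "tied-toy"

    def __init__(self, hidden_size=16, vocab_size=32, **kwargs):
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        super().__init__(**kwargs)


class TiedToyModel(PreTrainedModel):
    config_class = TiedToyConfig
    base_model_prefix = "tied_toy"
    # Tied embeddings mean the checkpoint cannot simply be assigned in place,
    # so ask `from_pretrained` to keep using the slower copy-based load.
    supports_param_buffer_assignment = False

    def __init__(self, config):
        super().__init__(config)
        self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.lm_head.weight = self.embed.weight  # weight tying
        self.post_init()

    def forward(self, input_ids):
        return self.lm_head(self.embed(input_ids))
```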
6 changes: 6 additions & 0 deletions tests/models/bart/test_modeling_bart.py
@@ -512,6 +512,12 @@ def test_generate_fp16(self):
model.generate(input_ids, attention_mask=attention_mask)
model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)

@unittest.skip(
reason="This architecure have tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771"
muellerzr marked this conversation as resolved.
Show resolved Hide resolved
)
def test_load_save_without_tied_weights(self):
pass


def assert_tensors_close(a, b, atol=1e-12, prefix=""):
"""If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""
6 changes: 6 additions & 0 deletions tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py
@@ -476,6 +476,12 @@ def test_for_change_to_full_attn(self):

self.assertTrue(torch.allclose(outputs1, outputs2, atol=1e-5))

@unittest.skip(
reason="This architecure have tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771"
)
def test_load_save_without_tied_weights(self):
pass


@require_torch
@require_sentencepiece
12 changes: 12 additions & 0 deletions tests/models/longt5/test_modeling_longt5.py
@@ -758,6 +758,12 @@ def _check_encoder_attention_for_generate(self, attentions, batch_size, config,
[encoder_expected_shape] * len(attentions),
)

@unittest.skip(
reason="This architecure have tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771"
)
def test_load_save_without_tied_weights(self):
pass


@require_torch
class LongT5TGlobalModelTest(LongT5ModelTest):
@@ -1097,6 +1103,12 @@ def test_attention_outputs(self):
[self.model_tester.num_attention_heads, block_len, 3 * block_len],
)

@unittest.skip(
reason="This architecure have tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771"
)
def test_load_save_without_tied_weights(self):
pass


class LongT5EncoderOnlyTGlobalModelTest(LongT5EncoderOnlyModelTest):
def setUp(self):
6 changes: 6 additions & 0 deletions tests/models/lxmert/test_modeling_lxmert.py
@@ -778,6 +778,12 @@ def test_save_load_low_cpu_mem_usage_checkpoints(self):
def test_save_load_low_cpu_mem_usage_no_safetensors(self):
pass

@unittest.skip(
reason="This architecure have tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771"
)
def test_load_save_without_tied_weights(self):
pass


@require_torch
class LxmertModelIntegrationTest(unittest.TestCase):
6 changes: 6 additions & 0 deletions tests/models/m2m_100/test_modeling_m2m_100.py
@@ -331,6 +331,12 @@ def test_generate_fp16(self):
model.generate(input_ids, attention_mask=attention_mask)
model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)

@unittest.skip(
reason="This architecure have tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771"
)
def test_load_save_without_tied_weights(self):
pass


def _long_tensor(tok_lst):
return torch.tensor(tok_lst, dtype=torch.long, device=torch_device)
6 changes: 6 additions & 0 deletions tests/models/mbart/test_modeling_mbart.py
@@ -369,6 +369,12 @@ def test_ensure_weights_are_shared(self):
2,
)

@unittest.skip(
reason="This architecure have tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771"
)
def test_load_save_without_tied_weights(self):
pass


def assert_tensors_close(a, b, atol=1e-12, prefix=""):
"""If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""
6 changes: 6 additions & 0 deletions tests/models/nllb_moe/test_modeling_nllb_moe.py
@@ -346,6 +346,12 @@ def test_get_loss(self):
self.assertIsNotNone(model(**input_dict)["encoder_router_logits"][1])
self.assertIsNotNone(model(**input_dict)["decoder_router_logits"][0])

@unittest.skip(
reason="This architecure have tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771"
)
def test_load_save_without_tied_weights(self):
pass


@require_torch
@require_sentencepiece
6 changes: 6 additions & 0 deletions tests/models/plbart/test_modeling_plbart.py
@@ -323,6 +323,12 @@ def test_generate_fp16(self):
def test_sample_generate(self):
pass

@unittest.skip(
reason="This architecure have tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771"
)
def test_load_save_without_tied_weights(self):
pass


def assert_tensors_close(a, b, atol=1e-12, prefix=""):
"""If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""
12 changes: 12 additions & 0 deletions tests/models/seamless_m4t/test_modeling_seamless_m4t.py
@@ -506,6 +506,12 @@ def test_training_gradient_checkpointing_use_reentrant(self):
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass

@unittest.skip(
reason="This architecure have tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771"
)
def test_load_save_without_tied_weights(self):
pass

def test_attention_outputs(self):
# expected length is subsampled so need to change a bit this test
if not self.has_attentions:
@@ -758,6 +764,12 @@ def test_training_gradient_checkpointing_use_reentrant_false(self):
def test_retain_grad_hidden_states_attentions(self):
pass

@unittest.skip(
reason="This architecure have tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771"
)
def test_load_save_without_tied_weights(self):
pass


@require_torch
class SeamlessM4TGenerationTest(unittest.TestCase):
12 changes: 12 additions & 0 deletions tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py
@@ -522,6 +522,12 @@ def test_training_gradient_checkpointing_use_reentrant(self):
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass

@unittest.skip(
reason="This architecure have tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771"
)
def test_load_save_without_tied_weights(self):
pass

def test_attention_outputs(self):
# expected length is subsampled so need to change a bit this test
if not self.has_attentions:
@@ -748,6 +754,12 @@ def test_training_gradient_checkpointing_use_reentrant(self):
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass

@unittest.skip(
reason="This architecure have tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771"
)
def test_load_save_without_tied_weights(self):
pass


@require_torch
class SeamlessM4Tv2GenerationTest(unittest.TestCase):
tests/models/switch_transformers/test_modeling_switch_transformers.py
@@ -720,6 +720,12 @@ def test_generate_with_head_masking(self):
attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0)

@unittest.skip(
reason="This architecure have tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771"
)
def test_load_save_without_tied_weights(self):
pass


class SwitchTransformersEncoderOnlyModelTester:
def __init__(
@@ -843,6 +849,12 @@ def test_model_fp16_forward(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs)

@unittest.skip(
reason="This architecure have tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771"
)
def test_load_save_without_tied_weights(self):
pass


def use_task_specific_params(model, task):
model.config.update(model.config.task_specific_params[task])
32 changes: 17 additions & 15 deletions tests/utils/test_modeling_utils.py
@@ -20,6 +20,7 @@
import sys
import tempfile
import threading
import time
import unittest
import unittest.mock as mock
import uuid
@@ -895,31 +896,32 @@ def test_from_pretrained_low_cpu_mem_usage_functional(self):
@require_accelerate
@mark.accelerate_tests
def test_from_pretrained_low_cpu_mem_usage_measured(self):
# test that `from_pretrained(..., low_cpu_mem_usage=True)` uses less cpu memory than default
# Before this would test that `from_pretrained(..., low_cpu_mem_usage=True)` uses less cpu memory than default
# Now though the memory is the same, we simply test that loading with `low_cpu_mem_usage` winds up being *faster*

mname = "google-bert/bert-base-cased"

preamble = "from transformers import AutoModel"
one_liner_str = f'{preamble}; AutoModel.from_pretrained("{mname}", low_cpu_mem_usage=False)'
max_rss_normal = self.python_one_liner_max_rss(one_liner_str)
start_time = time.time()
# Save this output as `max_rss_normal` if testing memory results
_ = self.python_one_liner_max_rss(one_liner_str)
end_time = time.time()
elapsed_time_normal = end_time - start_time
# print(f"{max_rss_normal=}")

one_liner_str = f'{preamble}; AutoModel.from_pretrained("{mname}", low_cpu_mem_usage=True)'
max_rss_low_mem = self.python_one_liner_max_rss(one_liner_str)
# print(f"{max_rss_low_mem=}")

diff_bytes = max_rss_normal - max_rss_low_mem
diff_percent = diff_bytes / max_rss_low_mem
# print(f"{diff_bytes=}, {diff_percent=}")
# ideally we would compare that the diff is close to ~1x checkpoint size in bytes, but
# measuring cpu memory on linux is very tricky and inconsistent, so instead let's check that
# it's at least 15% less cpu memory consumed
start_time = time.time()
# Save this output as `max_rss_low_mem` if testing memory results
_ = self.python_one_liner_max_rss(one_liner_str)
end_time = time.time()
elapsed_time_low_mem = end_time - start_time

self.assertGreater(
diff_percent,
0.15,
"should use less CPU memory for low_cpu_mem_usage=True, "
f"but got max_rss_normal={max_rss_normal} and max_rss_low_mem={max_rss_low_mem}",
elapsed_time_normal,
elapsed_time_low_mem,
"using `low_cpu_mem_usage` should be faster, "
f"but got elapsed_time_normal={elapsed_time_normal} and elapsed_time_low_mem={elapsed_time_low_mem}",
)

# if you want to compare things manually, let's first look at the size of the model in bytes
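To reproduce the timing comparison from the rewritten test manually, a rough script along the same lines (not part of the PR; absolute numbers vary by machine, so no threshold is asserted):

```python
# Hedged sketch: time the two loading paths the way the test above does,
# but in a plain script. Run once beforehand to warm the local HF cache,
# otherwise download time dominates.
import time

from transformers import AutoModel

model_name = "google-bert/bert-base-cased"

start = time.time()
AutoModel.from_pretrained(model_name, low_cpu_mem_usage=False)
elapsed_normal = time.time() - start

start = time.time()
AutoModel.from_pretrained(model_name, low_cpu_mem_usage=True)
elapsed_low_mem = time.time() - start

print(f"normal: {elapsed_normal:.2f}s, low_cpu_mem_usage: {elapsed_low_mem:.2f}s")
```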