Commit

Extend save_pretrained to offloaded models (huggingface#27412)
* added hidden subset

* debugged hidden subset contrastive search

* added contrastive search compression

* debugged compressed contrastive search

* memory reduction for contrastive search

* debugged mem red

* added low memory option feature

* debugged mem optimization output stack

* debugged mem optimization output stack

* debugged low mem

* added low mem cache

* fixed 2047 tensor view

* debugged 2042 past key val inputs

* reformatted tensors

* changed low mem output

* final clean

* removed subset hidden csearch

* fixed hidden device

* fixed hidden device

* changed compressor dtype

* removed hstate compression

* integrated csearch in generate

* test csearch integration into generation

exit()

* fixed csearch kwarg integration with generation

* final wrap and added doc

* Update src/transformers/generation/utils.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/utils.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/utils.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* added debug print

* direct hstate cat

* direct hstate cat

* direct hstate cat debug

* direct hstate cat debug

* expanded full hidden state stack

* expanded full hidden state stack

* matched dims for hstates

* matched dims for hstates

* logits fix

* equality test

* equality hidden debug

* debug

* added prints for debug

* added prints for debug

* equality check

* switched squeeze dim

* input format debug

* tracing top_k_ids

* removed trace

* added test context

* added jitter

* added jitter

* added jitter

* returned state

* rebuilt past key value reconstruction

* debugged

* cleaned traces

* added selection for pkv

* changed output to dict

* cleaned

* cleaned

* cleaned up contrastive search test

* moved low_memory kwarg

* debugged

* changed low mem test batch size to 1

* removed output

* debugged test input shape

* reformatted csearch test

* added trace

* removed unsqueeze on final forward pass

* replaced unsqueeze with view

* removed traces

* cleaned

* debugged model kwargs

* removed special models from test

* ran make quality

* Update src/transformers/generation/configuration_utils.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/configuration_utils.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* refactored

* refactored

* refactored

* make fixup

* renamed flag sequential

* renamed flag sequential

* iterative onloading

* black style and test utils

* added traces for integrated test

* debugged

* added traces

* make style

* removed traces, make style

* included suggestions and added test

* debugged test

* added offload module check and make style

* is_accelerate_available and make style

* added test decorator

* changed test model and config spec

* added offload condition

* added lazy loading for each shard

* debugged

* modified sharding

* debugged

* added traces

* removed safe serialization

* no index overload;

* trace on safe save ptrs

* added ptr condition

* debugged

* debugged ptr

* moved module map init

* remake shard only for offloaded modules

* refactored

* debugged

* refactored

* debugged

* cleaned and make style

* cleaned and make style

* added trace

* sparse module map

* debugged

* removed module map conditional

* refactored

* debug

* debugged

* added traces

* added shard mem trace

* added shard mem trace

* removed underlying storage check

* refactored

* memory leak removal and make style

* cleaned

* swapped test decs and make style

* added mem checks and make style

* added free mem warning

* implemented some suggestions

* moved onloading to accelerate

* refactored for accelerate integration

* cleaned test

* make style

* debugged offload map name

* cleaned and make style

* replaced meta device check for sharding

* cleaned and make style

* implemented some suggestions

* more suggestions

* update warning

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* more suggestions

* make style

* new make style

* Update src/transformers/modeling_utils.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/transformers/modeling_utils.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/transformers/modeling_utils.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/transformers/modeling_utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
4 people authored and zucchini-nlp committed Jun 14, 2024
1 parent e9f0ab9 commit 3ff0139
Showing 2 changed files with 95 additions and 6 deletions.
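
For context, a minimal usage sketch of the behavior this commit enables. This is illustrative only: the checkpoint id and folder names are placeholders, and whether anything is actually offloaded depends on the machine; device_map="auto" may place some modules on "cpu" or "disk" when accelerator memory is tight.

from transformers import AutoModelForCausalLM

# Load with a device map that may offload some modules to CPU RAM or disk
model = AutoModelForCausalLM.from_pretrained(
    "hf-internal-testing/tiny-random-gpt2",
    device_map="auto",
    offload_folder="offload",  # where disk-offloaded weights go, if any
)

# With this commit, saving works even when some parameters live on "cpu"/"disk":
# offloaded weights are onloaded shard by shard before being written out.
model.save_pretrained("saved_model")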
64 changes: 58 additions & 6 deletions src/transformers/modeling_utils.py
@@ -119,6 +119,10 @@
        set_module_tensor_to_device,
    )

    accelerate_version = version.parse(importlib.metadata.version("accelerate"))
    if accelerate_version >= version.parse("0.31"):
        from accelerate.utils.modeling import get_state_dict_from_offload

if is_safetensors_available():
    from safetensors import safe_open
    from safetensors.torch import load_file as safe_load_file
@@ -374,13 +378,12 @@ def shard_checkpoint(
            storage_id = id_tensor_storage(weight)

        # If a `weight` shares the same underlying storage as another tensor, we put `weight` in the same `block`
        if storage_id in storage_id_to_block:
        if storage_id in storage_id_to_block and weight.device != torch.device("meta"):
            block_id = storage_id_to_block[storage_id]
            sharded_state_dicts[block_id][key] = weight
            continue

        weight_size = weight.numel() * dtype_byte_size(weight.dtype)

        # If this weight is going to tip up over the maximal size, we split, but only if we have put at least one
        # weight in the current shard.
        if last_block_size + weight_size > max_shard_size and len(sharded_state_dicts[-1]) > 0:
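
A tiny illustration (plain torch, hypothetical tensors, not part of the diff) of the condition added above: weights materialized on the meta device carry no real data, so they must not be grouped into an existing block via their storage id.

import torch

real_weight = torch.zeros(2, 2)
meta_weight = torch.empty(2, 2, device="meta")

print(real_weight.device != torch.device("meta"))  # True  -> eligible for shared-storage grouping
print(meta_weight.device != torch.device("meta"))  # False -> falls through to normal size-based sharding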
@@ -2504,8 +2507,26 @@ def save_pretrained(
                current_peft_config = self.peft_config[active_adapter]
                current_peft_config.save_pretrained(save_directory)

        # for offloaded modules
        module_map = {}

        # Save the model
        if state_dict is None:
            # if any model parameters are offloaded to the disk, make module map
            if hasattr(self, "hf_device_map") and (
                "cpu" in self.hf_device_map.values() or "disk" in self.hf_device_map.values()
            ):
                warnings.warn(
                    "Attempting to save a model with offloaded modules. Ensure that unallocated cpu memory exceeds the `shard_size` (5GB default)"
                )
                for name, module in model_to_save.named_modules():
                    if name == "":
                        continue
                    module_state_dict = module.state_dict()

                    for key in module_state_dict:
                        module_map[name + f".{key}"] = module

            state_dict = model_to_save.state_dict()

        # Translate state_dict from smp to hf if saving with smp >= 1.10
@@ -2531,12 +2552,24 @@
                    # In the non-tensor case, fall back to the pointer of the object itself
                    ptrs[id(tensor)].append(name)

            # These are all the pointers of shared tensors.
            shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
            error_names = []
            to_delete_names = set()
            # These are all the pointers of shared tensors
            if hasattr(self, "hf_device_map"):
                # if the model has offloaded parameters, we must check using find_tied_parameters()
                tied_params = find_tied_parameters(self)
                if tied_params:
                    tied_names = tied_params[0]
                    shared_ptrs = {
                        ptr: names for ptr, names in ptrs.items() if any(name in tied_names for name in names)
                    }
                else:
                    shared_ptrs = {}
            else:
                shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}

            # Recursively descend to find tied weight keys
            _tied_weights_keys = _get_tied_weight_keys(self)
            error_names = []
            to_delete_names = set()
            for names in shared_ptrs.values():
                # Removing the keys which are declared as known duplicates on
                # load. This allows to make sure the name which is kept is consistent.
@@ -2609,6 +2642,25 @@

        # Save the model
        for shard_file, shard in shards.items():
            # remake shard with onloaded parameters if necessary
            if module_map:
                if accelerate_version < version.parse("0.31"):
                    raise ImportError(
                        f"You need accelerate version to be greater or equal than 0.31 to save models with offloaded parameters. Detected version {accelerate_version}. "
                        f"Please upgrade accelerate with `pip install -U accelerate`"
                    )
                # init state_dict for this shard
                state_dict = {name: "" for name in shard}
                for module_name in shard:
                    module = module_map[module_name]
                    # update state dict with onloaded parameters
                    state_dict = get_state_dict_from_offload(module, module_name, state_dict)

                # assign shard to be the completed state dict
                shard = state_dict
                del state_dict
                gc.collect()

            if safe_serialization:
                # At some point we will need to deal better with save_function (used for TPU and other distributed
                # joyfulness), but for now this enough.
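
As a side note, a self-contained sketch (a toy torch.nn model, not the transformers code above) of the name-to-module map that save_pretrained now builds for offloaded models: every fully qualified parameter name is mapped to the submodule that owns it, so matching shard entries can later be onloaded module by module.

import torch.nn as nn

model = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))

module_map = {}
for name, module in model.named_modules():
    if name == "":  # skip the root module itself
        continue
    for key in module.state_dict():
        module_map[f"{name}.{key}"] = module

print(sorted(module_map))  # ['0.bias', '0.weight', '1.bias', '1.weight']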
37 changes: 37 additions & 0 deletions tests/test_modeling_utils.py
@@ -1056,6 +1056,43 @@ def test_cached_files_are_used_when_internet_is_down(self):
            # This check we did call the fake head request
            mock_head.assert_called()

    @require_accelerate
    @mark.accelerate_tests
    @require_torch_accelerator
    def test_save_offloaded_model(self):
        device_map = {
            "transformer.wte": f"{torch_device}:0",
            "transformer.wpe": f"{torch_device}:0",
            "transformer.h.0": "cpu",
            "transformer.h.1": "cpu",
            "transformer.h.2": "cpu",
            "transformer.h.3": "disk",
            "transformer.h.4": "disk",
            "transformer.ln_f": f"{torch_device}:0",
            "lm_head": f"{torch_device}:0",
        }

        # check_models_equal requires onloaded tensors
        model_id = "hf-internal-testing/tiny-random-gpt2"
        onloaded_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cpu")
        inputs = torch.tensor([[1, 2, 3]]).to(f"{torch_device}:0")
        cpu_output = onloaded_model(inputs)[0]

        with tempfile.TemporaryDirectory() as tmp_dir:
            offload_folder = os.path.join(tmp_dir, "offload")
            offloaded_model = AutoModelForCausalLM.from_pretrained(
                model_id, device_map=device_map, offload_folder=offload_folder
            )
            presaved_output = offloaded_model(inputs)[0]
            offloaded_model.save_pretrained(
                tmp_dir, max_shard_size="200KB"
            )  # model is 1.6MB, max shard size is allocated to cpu by default
            saved_model = AutoModelForCausalLM.from_pretrained(tmp_dir, device_map=device_map)
            postsaved_output = saved_model(inputs)[0]

        self.assertTrue(torch.allclose(cpu_output, presaved_output, atol=1e-4))
        self.assertTrue(torch.allclose(presaved_output, postsaved_output))

    @require_safetensors
    def test_use_safetensors(self):
        # Should not raise anymore
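
To exercise the new test locally, something along these lines should work (a sketch; the test is skipped automatically unless accelerate and a torch accelerator are available, per the decorators above):

python -m pytest tests/test_modeling_utils.py -k test_save_offloaded_model -v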
