From 695d5120e0ffb6aea20682e650afa61ba4f78b7d Mon Sep 17 00:00:00 2001
From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL"
Date: Mon, 15 Jul 2024 16:05:55 -0400
Subject: [PATCH] Add new func

---
 src/transformers/modeling_utils.py | 67 ++++++++++++++----------
 1 file changed, 31 insertions(+), 36 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 997f295e812dd9..330ae0555d1993 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -338,6 +338,31 @@ def dtype_byte_size(dtype):
     return bit_size // 8
 
 
+def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
+    """
+    Checks if `model_to_load` supports param buffer assignment (such
+    as when loading in empty weights) by first checking
+    if the model explicitly disables it, then by ensuring that the state dict keys
+    are a subset of the model's parameters.
+    """
+    if len([key for key in state_dict if key.startswith(start_prefix)]) > 0:
+        # Some models explicitly do not support param buffer assignment
+        if hasattr(model_to_load, "supports_param_buffer_assignment"):
+            logger.debug(
+                f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
+            )
+            return False
+        else:
+            # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
+            first_key = list(model_to_load.state_dict().keys())[0]
+            if start_prefix + first_key in state_dict:
+                return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype
+            else:
+                # For cases when the `state_dict` doesn't have any real weights (`albert`)
+                return False
+    return False
+
+
 def shard_checkpoint(
     state_dict: Dict[str, torch.Tensor], max_shard_size: Union[int, str] = "10GB", weights_name: str = WEIGHTS_NAME
 ):
@@ -4254,25 +4279,10 @@ def _find_mismatched_keys(
                     unexpected_keys=unexpected_keys,
                 )
             else:
-                assign_to_params_buffers = False
                 # Sharded checkpoint or whole but low_cpu_mem_usage==True
-                if len([key for key in state_dict if key.startswith(start_prefix)]) > 0:
-                    # Some models do not support param buffer assignment
-                    if hasattr(model_to_load, "supports_param_buffer_assignment"):
-                        logger.debug(
-                            f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
-                        )
-                    elif all(start_prefix + k in state_dict for k in model_to_load.state_dict().keys()):
-                        # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype and have all their keys
-                        first_key = list(model_to_load.state_dict().keys())[0]
-                        if start_prefix + first_key in state_dict:
-                            assign_to_params_buffers = (
-                                state_dict[start_prefix + first_key].dtype
-                                == model_to_load.state_dict()[first_key].dtype
-                            )
-                        else:
-                            # For cases when the `state_dict` doesn't have any real weights (`albert`)
-                            assign_to_params_buffers = False
+                assign_to_params_buffers = check_support_param_buffer_assignment(
+                    model_to_load, state_dict, start_prefix
+                )
                 error_msgs = _load_state_dict_into_model(
                     model_to_load, state_dict, start_prefix, assign_to_params_buffers
                 )
@@ -4349,24 +4359,9 @@ def _find_mismatched_keys(
                 else:
                     # Sharded checkpoint or whole but low_cpu_mem_usage==True
                     if assign_to_params_buffers is None:
-                        if len([key for key in state_dict if key.startswith(start_prefix)]) > 0:
-                            # Some models do not support param buffer assignment
-                            if hasattr(model_to_load, "supports_param_buffer_assignment"):
-                                logger.debug(
-                                    f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
-                                )
-                                assign_to_params_buffers = False
-                            else:
-                                # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
-                                first_key = list(model_to_load.state_dict().keys())[0]
-                                if start_prefix + first_key in state_dict:
-                                    assign_to_params_buffers = (
-                                        state_dict[start_prefix + first_key].dtype
-                                        == model_to_load.state_dict()[first_key].dtype
-                                    )
-                                else:
-                                    # For cases when the `state_dict` doesn't have any real weights (`albert`)
-                                    assign_to_params_buffers = False
+                        assign_to_params_buffers = check_support_param_buffer_assignment(
+                            model_to_load, state_dict, start_prefix
+                        )
                     error_msgs += _load_state_dict_into_model(
                         model_to_load, state_dict, start_prefix, assign_to_params_buffers
                     )