Commit

Add new func
muellerzr committed Jul 15, 2024
1 parent 7b424fa commit 695d512
Showing 1 changed file with 31 additions and 36 deletions.
67 changes: 31 additions & 36 deletions src/transformers/modeling_utils.py
@@ -338,6 +338,31 @@ def dtype_byte_size(dtype):
     return bit_size // 8
 
 
+def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
+    """
+    Checks if `model_to_load` supports param buffer assignment (such
+    as when loading in empty weights) by first checking
+    if the model explicitly disables it, then by ensuring that the state dict keys
+    are a subset of the model's parameters.
+    """
+    if len([key for key in state_dict if key.startswith(start_prefix)]) > 0:
+        # Some models explicitly do not support param buffer assignment
+        if hasattr(model_to_load, "supports_param_buffer_assignment"):
+            logger.debug(
+                f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
+            )
+            return False
+        else:
+            # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
+            first_key = list(model_to_load.state_dict().keys())[0]
+            if start_prefix + first_key in state_dict:
+                return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype
+            else:
+                # For cases when the `state_dict` doesn't have any real weights (`albert`)
+                return False
+    return False
+
+
 def shard_checkpoint(
     state_dict: Dict[str, torch.Tensor], max_shard_size: Union[int, str] = "10GB", weights_name: str = WEIGHTS_NAME
 ):
@@ -4254,25 +4279,10 @@ def _find_mismatched_keys(
                     unexpected_keys=unexpected_keys,
                 )
            else:
-                assign_to_params_buffers = False
                 # Sharded checkpoint or whole but low_cpu_mem_usage==True
-                if len([key for key in state_dict if key.startswith(start_prefix)]) > 0:
-                    # Some models do not support param buffer assignment
-                    if hasattr(model_to_load, "supports_param_buffer_assignment"):
-                        logger.debug(
-                            f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
-                        )
-                    elif all(start_prefix + k in state_dict for k in model_to_load.state_dict().keys()):
-                        # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype and have all their keys
-                        first_key = list(model_to_load.state_dict().keys())[0]
-                        if start_prefix + first_key in state_dict:
-                            assign_to_params_buffers = (
-                                state_dict[start_prefix + first_key].dtype
-                                == model_to_load.state_dict()[first_key].dtype
-                            )
-                        else:
-                            # For cases when the `state_dict` doesn't have any real weights (`albert`)
-                            assign_to_params_buffers = False
+                assign_to_params_buffers = check_support_param_buffer_assignment(
+                    model_to_load, state_dict, start_prefix
+                )
                 error_msgs = _load_state_dict_into_model(
                     model_to_load, state_dict, start_prefix, assign_to_params_buffers
                 )
@@ -4349,24 +4359,9 @@ def _find_mismatched_keys(
                 else:
                     # Sharded checkpoint or whole but low_cpu_mem_usage==True
                     if assign_to_params_buffers is None:
-                        if len([key for key in state_dict if key.startswith(start_prefix)]) > 0:
-                            # Some models do not support param buffer assignment
-                            if hasattr(model_to_load, "supports_param_buffer_assignment"):
-                                logger.debug(
-                                    f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
-                                )
-                                assign_to_params_buffers = False
-                            else:
-                                # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
-                                first_key = list(model_to_load.state_dict().keys())[0]
-                                if start_prefix + first_key in state_dict:
-                                    assign_to_params_buffers = (
-                                        state_dict[start_prefix + first_key].dtype
-                                        == model_to_load.state_dict()[first_key].dtype
-                                    )
-                                else:
-                                    # For cases when the `state_dict` doesn't have any real weights (`albert`)
-                                    assign_to_params_buffers = False
+                        assign_to_params_buffers = check_support_param_buffer_assignment(
+                            model_to_load, state_dict, start_prefix
+                        )
                     error_msgs += _load_state_dict_into_model(
                         model_to_load, state_dict, start_prefix, assign_to_params_buffers
                     )
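
The new helper is only a guard: when it returns True, loading uses the same mechanism that torch's `Module.load_state_dict` exposes as `assign=True`, swapping the checkpoint tensors into the module instead of copying them into pre-allocated parameters. The following minimal sketch (not part of the commit; it assumes PyTorch >= 2.1 for the `assign` argument and a transformers release that already contains this helper) shows what the check does and why the dtype comparison matters:

import torch
import torch.nn as nn

from transformers.modeling_utils import check_support_param_buffer_assignment

model = nn.Linear(4, 4)                                          # fp32 parameters
ckpt_fp32 = model.state_dict()                                   # same dtype as the model
ckpt_fp16 = {k: v.half() for k, v in ckpt_fp32.items()}          # mismatched dtype

# The helper returns True only when the first checkpoint tensor matches the model's dtype
# (and the model does not opt out via a `supports_param_buffer_assignment` attribute).
print(check_support_param_buffer_assignment(model, ckpt_fp32))   # True  -> assignment is safe
print(check_support_param_buffer_assignment(model, ckpt_fp16))   # False -> fall back to copying

# With assign=True the checkpoint tensors replace the parameters outright,
# so a dtype mismatch would silently change the model's dtype:
model.load_state_dict(ckpt_fp16, assign=True)
print(model.weight.dtype)                                        # torch.float16

Because only the first key's dtype is compared, the guard stays cheap regardless of checkpoint size.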
