Reload big model with multiple state dict files (#1644)
* Reload big model with multiple state dict files

* Add description for reload func
natuan committed Jul 3, 2023
1 parent a61f89a commit b73a173
Showing 1 changed file with 28 additions and 17 deletions.
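For context before the diff: the change teaches the trainer to reload weights from a checkpoint directory that holds either a single pytorch_model.bin or several sharded pytorch_model-*.bin files, merging them into one state dict before loading. The snippet below is a minimal standalone sketch of that idea; the helper name load_sharded_state_dict is made up for illustration and is not part of the commit, which performs the same steps inside Trainer._reload_model_state.

import os
from typing import Any, Dict

import torch


def load_sharded_state_dict(load_path: str) -> Dict[str, Any]:
    """Merge every pytorch_model*.bin file under load_path into one state dict."""
    if not load_path or not os.path.isdir(load_path):
        return {}
    # matches the single-file name "pytorch_model.bin" as well as sharded names
    # like "pytorch_model-00001-of-00002.bin"; the *.index.json file is skipped
    weight_files = [
        os.path.join(load_path, f)
        for f in os.listdir(load_path)
        if f.startswith("pytorch_model") and f.endswith("bin")
    ]
    merged: Dict[str, Any] = {}
    for weight_file in weight_files:
        # each shard holds a disjoint subset of parameters, so update() does not clobber keys
        merged.update(torch.load(weight_file, map_location="cpu"))
    return merged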
45 changes: 28 additions & 17 deletions src/sparseml/transformers/sparsification/trainer.py
@@ -33,7 +33,7 @@
 from torch.nn import Module
 from transformers import Trainer as HFTransformersTrainer
 from transformers import TrainerCallback, TrainerControl, TrainingArguments
-from transformers.file_utils import WEIGHTS_NAME, PaddingStrategy
+from transformers.file_utils import PaddingStrategy
 from transformers.integrations import TensorBoardCallback
 from transformers.trainer_callback import TrainerState
 from transformers.trainer_pt_utils import reissue_pt_warnings
@@ -218,12 +218,13 @@ def apply_manager(self, epoch: float, checkpoint: Optional[str]) -> bool:
 
         # reload the state dict for the model now that architecture matches expected
         load_path = checkpoint or self.model_state_path
-        self._reload_model_state(load_path, orig_state_dict)
+        if self._reload_model_state(load_path, orig_state_dict):
+            _LOGGER.info(
+                "Reloaded model state after SparseML recipe structure modifications "
+                f"from {load_path}"
+            )
+
         self.manager_applied = True
-        _LOGGER.info(
-            "Reloaded model state after SparseML recipe structure modifications "
-            f"from {load_path}"
-        )
 
         return True
 
@@ -652,27 +653,36 @@ def _setup_manager(
         return manager, arch_manager
 
     def _reload_model_state(self, load_path: str, orig_state_dict: Dict[str, Any]):
-        if (
-            not load_path
-            or not os.path.isdir(load_path)
-            or not os.path.isfile(os.path.join(load_path, WEIGHTS_NAME))
-        ):
+        """
+        Reload the weights after model arch changes due to recipe application
+        Return True if weights are successfully reloaded; False otherwise
+        """
+        invalid_load_path = not load_path or not os.path.isdir(load_path)
+        files = os.listdir(load_path) if not invalid_load_path else []
+        weight_files = [
+            os.path.join(load_path, f)
+            for f in files
+            if f.startswith("pytorch_model") and f.endswith("bin")
+        ]
+        if not weight_files:
             _LOGGER.warning(
                 "Model state was not reloaded for SparseML: "
-                f"could not find model weights for model_path {load_path}"
+                f"could not find model weights for {load_path}"
             )
-            return
+            return False
 
         current_state_dict = self.model.state_dict()
 
         if set(orig_state_dict.keys()) == set(current_state_dict):
             # no change in keys, ignore reload
-            return
+            return False
 
         # change in keys due to architecture changes, reload statedict
-        loaded_state_dict = torch.load(
-            os.path.join(load_path, WEIGHTS_NAME), map_location="cpu"
-        )
+        loaded_state_dict = {}
+        for f in weight_files:
+            dd = torch.load(os.path.join(load_path, f), map_location="cpu")
+            loaded_state_dict.update(dd)
 
         _, missing, unexpected, _, _ = self.model._load_pretrained_model(
             model=self.model,
             state_dict=loaded_state_dict,
@@ -704,6 +714,7 @@ def _reload_model_state(self, load_path: str, orig_state_dict: Dict[str, Any]):
             model_type="student" if self.teacher else "model",
             delayed_load=False,
         )
+        return True
 
     def _data_loader_builder(self, kwargs: Optional[Dict[str, Any]] = None):
        default_loader = self.get_train_dataloader()
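A short usage sketch of the caller-side pattern, mirroring the apply_manager change above: the reload helper returns a boolean and the success message is logged only when weights were actually found. It reuses the hypothetical load_sharded_state_dict helper from the sketch near the top and is not the Trainer method itself; the commit instead routes the merged state dict through self.model._load_pretrained_model so that missing and unexpected keys can be reported.

import logging

from torch.nn import Module

_LOGGER = logging.getLogger(__name__)


def reload_model_state(model: Module, load_path: str) -> bool:
    """Return True if a (possibly sharded) state dict was found and loaded into model."""
    state_dict = load_sharded_state_dict(load_path)  # hypothetical helper from the earlier sketch
    if not state_dict:
        _LOGGER.warning("Model state was not reloaded: no weights found in %s", load_path)
        return False
    # strict=False tolerates keys that changed after recipe-driven architecture modifications
    model.load_state_dict(state_dict, strict=False)
    return True


# caller-side pattern, as in apply_manager:
# if reload_model_state(model, load_path):
#     _LOGGER.info(f"Reloaded model state from {load_path}")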