Save optimizer as FP16 for smaller checkpoints (ultralytics#9435)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
2 people authored and hmurari committed Apr 17, 2024
1 parent 37fbb06 commit aac5dc2
Showing 3 changed files with 20 additions and 1 deletion.
4 changes: 4 additions & 0 deletions docs/en/reference/utils/torch_utils.md
@@ -115,6 +115,10 @@ keywords: Ultralytics, Torch Utils, Model EMA, Early Stopping, Smart Inference,

<br><br>

## ::: ultralytics.utils.torch_utils.convert_optimizer_state_dict_to_fp16

<br><br>

## ::: ultralytics.utils.torch_utils.profile

<br><br>
3 changes: 2 additions & 1 deletion ultralytics/engine/trainer.py
@@ -42,6 +42,7 @@
from ultralytics.utils.torch_utils import (
EarlyStopping,
ModelEMA,
convert_optimizer_state_dict_to_fp16,
init_seeds,
one_cycle,
select_device,
@@ -488,7 +489,7 @@ def save_model(self):
"model": None, # resume and final checkpoints derive from EMA
"ema": deepcopy(self.ema.ema).half(),
"updates": self.ema.updates,
"optimizer": self.optimizer.state_dict(),
"optimizer": convert_optimizer_state_dict_to_fp16(deepcopy(self.optimizer.state_dict())),
"train_args": vars(self.args), # save as dict
"train_metrics": {**self.metrics, **{"fitness": self.fitness}},
"train_results": {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()},
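The checkpoint dict above now passes the optimizer state through convert_optimizer_state_dict_to_fp16 on a deepcopy, so the live optimizer keeps its FP32 tensors while only the serialized copy is downcast. Below is a minimal sketch of the size effect; it is not part of the commit and assumes an Ultralytics install that includes this change (Adam is used only because it stores two FP32 moment tensors per parameter):

```python
import copy
import os
import tempfile

import torch
import torch.nn as nn

from ultralytics.utils.torch_utils import convert_optimizer_state_dict_to_fp16

# Toy model and optimizer; one step populates Adam's exp_avg / exp_avg_sq state.
model = nn.Linear(1024, 1024)
opt = torch.optim.Adam(model.parameters())
model(torch.randn(8, 1024)).sum().backward()
opt.step()

fp32_sd = opt.state_dict()
fp16_sd = convert_optimizer_state_dict_to_fp16(copy.deepcopy(fp32_sd))  # deepcopy keeps the live optimizer in FP32

with tempfile.TemporaryDirectory() as tmp:
    torch.save(fp32_sd, os.path.join(tmp, "fp32.pt"))
    torch.save(fp16_sd, os.path.join(tmp, "fp16.pt"))
    for name in ("fp32.pt", "fp16.pt"):
        print(name, round(os.path.getsize(os.path.join(tmp, name)) / 1e6, 2), "MB")
# Expected: the FP16 file is roughly half the size, since the moment tensors dominate.
```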
14 changes: 14 additions & 0 deletions ultralytics/utils/torch_utils.py
@@ -505,6 +505,20 @@ def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")


def convert_optimizer_state_dict_to_fp16(state_dict):
"""
Converts the state_dict of a given optimizer to FP16, focusing on the 'state' key for tensor conversions.
This method aims to reduce storage size without altering 'param_groups' as they contain non-tensor data.
"""
for state in state_dict["state"].values():
for k, v in state.items():
if isinstance(v, torch.Tensor) and v.dtype is torch.float32:
state[k] = v.half()

return state_dict


def profile(input, ops, n=10, device=None):
"""
Ultralytics speed, memory and FLOPs profiler.
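For completeness, a small round-trip sketch of the helper (again not part of the commit, with the same import assumption as above): it downcasts only FP32 tensors under the 'state' key, and recent PyTorch versions cast floating-point optimizer state back to the dtype of the matching parameters inside Optimizer.load_state_dict, so FP32 training can resume from the FP16 copy.

```python
import copy

import torch
import torch.nn as nn

from ultralytics.utils.torch_utils import convert_optimizer_state_dict_to_fp16

model = nn.Linear(4, 4)
opt = torch.optim.Adam(model.parameters())
model(torch.randn(2, 4)).sum().backward()
opt.step()

fp16_sd = convert_optimizer_state_dict_to_fp16(copy.deepcopy(opt.state_dict()))
print(fp16_sd["state"][0]["exp_avg"].dtype)  # torch.float16 after conversion

# Simulated resume: a fresh optimizer loads the FP16 state dict; load_state_dict
# casts floating-point state tensors to the dtype of the FP32 parameters.
resumed = torch.optim.Adam(model.parameters())
resumed.load_state_dict(fp16_sd)
print(resumed.state_dict()["state"][0]["exp_avg"].dtype)  # torch.float32 again
```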
