Skip to content

Commit

Permalink
Fix learning rate gap on resume (#9468)
Browse files Browse the repository at this point in the history
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: EunChan Kim <eunchan@hanyang.ac.kr>
Co-authored-by: Lakshantha Dissanayake <lakshanthad@yahoo.com>
Co-authored-by: RizwanMunawar <chr043416@gmail.com>
Co-authored-by: gs80140 <gs80140@users.noreply.github.com>
  • Loading branch information
7 people committed Apr 2, 2024
1 parent e5f4f5c commit 1e547e6
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
19 changes: 10 additions & 9 deletions ultralytics/engine/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,10 @@ def _do_train(self, world_size=1):
while True:
self.epoch = epoch
self.run_callbacks("on_train_epoch_start")
with warnings.catch_warnings():
warnings.simplefilter("ignore") # suppress 'Detected lr_scheduler.step() before optimizer.step()'
self.scheduler.step()

self.model.train()
if RANK != -1:
self.train_loader.sampler.set_epoch(epoch)
Expand Down Expand Up @@ -426,15 +430,12 @@ def _do_train(self, world_size=1):
t = time.time()
self.epoch_time = t - self.epoch_time_start
self.epoch_time_start = t
with warnings.catch_warnings():
warnings.simplefilter("ignore") # suppress 'Detected lr_scheduler.step() before optimizer.step()'
if self.args.time:
mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1)
self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time)
self._setup_scheduler()
self.scheduler.last_epoch = self.epoch # do not move
self.stop |= epoch >= self.epochs # stop if exceeded epochs
self.scheduler.step()
if self.args.time:
mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1)
self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time)
self._setup_scheduler()
self.scheduler.last_epoch = self.epoch # do not move
self.stop |= epoch >= self.epochs # stop if exceeded epochs
self.run_callbacks("on_fit_epoch_end")
torch.cuda.empty_cache() # clear GPU memory at end of epoch, may help reduce CUDA out of memory errors

Expand Down
5 changes: 3 additions & 2 deletions ultralytics/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import torch.nn.functional as F
import torchvision

from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, __version__
from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, __version__
from ultralytics.utils.checks import PYTHON_VERSION, check_version

try:
Expand Down Expand Up @@ -614,8 +614,9 @@ def __call__(self, epoch, fitness):
self.possible_stop = delta >= (self.patience - 1) # possible stop may occur next epoch
stop = delta >= self.patience # stop training if patience exceeded
if stop:
prefix = colorstr("EarlyStopping: ")
LOGGER.info(
f"Stopping training early as no improvement observed in last {self.patience} epochs. "
f"{prefix}Training stopped early as no improvement observed in last {self.patience} epochs. "
f"Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n"
f"To update EarlyStopping(patience={self.patience}) pass a new patience value, "
f"i.e. `patience=300` or use `patience=0` to disable EarlyStopping."
Expand Down

0 comments on commit 1e547e6

Please sign in to comment.