Raise LR schedule warnings only when necessary (#3207)
* compare and exit

* compare and exit

* revert changes

* edge case ://
snarayan21 authored and Chuck Tang committed May 16, 2024
1 parent e190323 commit b5fda07
Showing 2 changed files with 27 additions and 17 deletions.
43 changes: 26 additions & 17 deletions composer/optim/scheduler.py
@@ -537,30 +537,39 @@ def _raise_if_max_duration_exceeds_t_max(t_max: Union[str, Time], state: State):
         max_dur = Time.from_timestring(max_dur)
 
     max_dur_exceeds_t_max = False
-    if t_max.unit == max_dur.unit and t_max.value < max_dur.value:
-        max_dur_exceeds_t_max = True
-    elif (
-        t_max.unit == TimeUnit.BATCH and max_dur.unit == TimeUnit.EPOCH and state.dataloader_len is not None and
-        t_max.value < max_dur.value * int(state.dataloader_len)
-    ):
-        max_dur_exceeds_t_max = True
-    elif (
-        t_max.unit == TimeUnit.EPOCH and max_dur.unit == TimeUnit.BATCH and state.dataloader_len is not None and
-        t_max.value * int(state.dataloader_len) < max_dur.value
-    ):
-        max_dur_exceeds_t_max = True
-    elif t_max.unit != max_dur.unit:
-        log.info(
-            f'Since max_duration {max_dur} with units {max_dur.unit} and t_max {t_max} with units {t_max.unit} are not '
-            'comparable, make sure that your LR schedule is defined at all points in the training duration.',
-        )
+    if t_max.unit == max_dur.unit:
+        if t_max.value >= max_dur.value:
+            # Time units are comparable, and t_max is valid.
+            return
+        else:
+            max_dur_exceeds_t_max = True
+    elif (t_max.unit == TimeUnit.BATCH and max_dur.unit == TimeUnit.EPOCH and state.dataloader_len is not None):
+        if t_max.value >= max_dur.value * int(state.dataloader_len):
+            # Batches are comparable to epochs through the dataloader length, and t_max is valid.
+            return
+        else:
+            max_dur_exceeds_t_max = True
+    elif (t_max.unit == TimeUnit.EPOCH and max_dur.unit == TimeUnit.BATCH and state.dataloader_len is not None):
+        if t_max.value * int(state.dataloader_len) >= max_dur.value:
+            # Batches are comparable to epochs through the dataloader length, and t_max is valid.
+            return
+        else:
+            max_dur_exceeds_t_max = True
 
     if max_dur_exceeds_t_max:
+        # None of the checks above passed. Time units are comparable, but t_max is invalid since it's less than max_dur.
         raise ValueError(
             f't_max {t_max} must be greater than or equal to max_duration {max_dur}. Otherwise, the LR schedule will '
             'not be defined for the entire training duration.',
         )
 
+    if t_max.unit != max_dur.unit:
+        # Units are not comparable, so we cannot check if t_max is valid. Log this and return.
+        log.debug(
+            f'Since max_duration {max_dur} with units {max_dur.unit} and t_max {t_max} with units {t_max.unit} are not '
+            'comparable, make sure that your LR schedule is defined at all points in the training duration.',
+        )
+
 
 def _raise_if_warmup_and_max_incompatible(t_warmup: Time[int], t_max: Time[int]):
     """Checks that t_warmup and t_max have the same units.
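The behavior change in this diff: the old elif chain logged a unit-mismatch warning whenever t_max and max_duration had different units, even when the two were in fact comparable (batches and epochs convert through the dataloader length) and t_max was valid. The rewritten version returns early from every valid branch, so the warning fires only when no comparison is possible. Below is a minimal standalone sketch of that comparison logic, using plain strings and ints instead of composer's Time/TimeUnit/State types; it is illustrative only, not the library's API.

def t_max_covers_max_duration(t_max_value, t_max_unit, max_dur_value, max_dur_unit, dataloader_len=None):
    """Return True if t_max covers the full training duration, False if it
    falls short, and None if the two units cannot be compared."""
    if t_max_unit == max_dur_unit:
        # Same unit: compare values directly.
        return t_max_value >= max_dur_value
    if t_max_unit == 'ba' and max_dur_unit == 'ep' and dataloader_len is not None:
        # Convert max_duration from epochs to batches via batches-per-epoch.
        return t_max_value >= max_dur_value * dataloader_len
    if t_max_unit == 'ep' and max_dur_unit == 'ba' and dataloader_len is not None:
        # Convert t_max from epochs to batches instead.
        return t_max_value * dataloader_len >= max_dur_value
    # Units cannot be compared (e.g. tokens vs. epochs); the real code logs and returns.
    return None

# 10 epochs of 100 batches each is 1000 batches, so a t_max of 500 batches falls short:
assert t_max_covers_max_duration(500, 'ba', 10, 'ep', dataloader_len=100) is False
assert t_max_covers_max_duration(1000, 'ba', 10, 'ep', dataloader_len=100) is True
# Tokens vs. epochs are not comparable, so the real code logs instead of raising:
assert t_max_covers_max_duration(3, 'ep', 1000, 'tok') is None

A True result maps to the early returns in the diff, False maps to the ValueError, and None maps to the log.debug branch.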
1 change: 1 addition & 0 deletions tests/optim/test_scheduler.py
@@ -440,6 +440,7 @@ def test_warmup_schedulers_fail_fast(
     max_duration_no_epoch = state.max_duration
     if max_duration_unit == 'ep':
         max_duration_no_epoch = Time.from_timestring(max_duration_unit_to_str['ba'])
+
     error_context = contextlib.nullcontext()
     if (
         hasattr(scheduler, 't_max') and
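The added blank line slots into an existing pattern in test_warmup_schedulers_fail_fast: the test selects either contextlib.nullcontext() or pytest.raises(...) up front, then runs the code under test inside whichever context applies. A self-contained sketch of that pattern follows; run_case and its threshold check are illustrative stand-ins, not composer's test code.

import contextlib

import pytest

def run_case(t_max_batches: int, max_dur_batches: int):
    # Expect a ValueError only when t_max falls short of max_duration,
    # mirroring the fail-fast check the test above exercises.
    error_context = contextlib.nullcontext()
    if t_max_batches < max_dur_batches:
        error_context = pytest.raises(ValueError)
    with error_context:
        if t_max_batches < max_dur_batches:
            raise ValueError('t_max must cover max_duration')

run_case(500, 1000)   # ValueError raised and absorbed by pytest.raises
run_case(1000, 1000)  # no error; nullcontext is a no-op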
