Raise LR schedule warnings only when necessary #3207

Merged (4 commits) on Apr 24, 2024

composer/optim/scheduler.py (43 changes: 26 additions & 17 deletions)

```diff
@@ -537,30 +537,39 @@ def _raise_if_max_duration_exceeds_t_max(t_max: Union[str, Time], state: State):
         max_dur = Time.from_timestring(max_dur)
 
     max_dur_exceeds_t_max = False
-    if t_max.unit == max_dur.unit and t_max.value < max_dur.value:
-        max_dur_exceeds_t_max = True
-    elif (
-        t_max.unit == TimeUnit.BATCH and max_dur.unit == TimeUnit.EPOCH and state.dataloader_len is not None and
-        t_max.value < max_dur.value * int(state.dataloader_len)
-    ):
-        max_dur_exceeds_t_max = True
-    elif (
-        t_max.unit == TimeUnit.EPOCH and max_dur.unit == TimeUnit.BATCH and state.dataloader_len is not None and
-        t_max.value * int(state.dataloader_len) < max_dur.value
-    ):
-        max_dur_exceeds_t_max = True
-    elif t_max.unit != max_dur.unit:
-        log.info(
-            f'Since max_duration {max_dur} with units {max_dur.unit} and t_max {t_max} with units {t_max.unit} are not '
-            'comparable, make sure that your LR schedule is defined at all points in the training duration.',
-        )
+    if t_max.unit == max_dur.unit:
+        if t_max.value >= max_dur.value:
+            # Time units are comparable, and t_max is valid.
+            return
+        else:
+            max_dur_exceeds_t_max = True
+    elif (t_max.unit == TimeUnit.BATCH and max_dur.unit == TimeUnit.EPOCH and state.dataloader_len is not None):
+        if t_max.value >= max_dur.value * int(state.dataloader_len):
+            # Batches are comparable to epochs through the dataloader length, and t_max is valid.
+            return
+        else:
+            max_dur_exceeds_t_max = True
+    elif (t_max.unit == TimeUnit.EPOCH and max_dur.unit == TimeUnit.BATCH and state.dataloader_len is not None):
+        if t_max.value * int(state.dataloader_len) >= max_dur.value:
+            # Batches are comparable to epochs through the dataloader length, and t_max is valid.
+            return
+        else:
+            max_dur_exceeds_t_max = True
+
     if max_dur_exceeds_t_max:
+        # None of the checks above passed. Time units are comparable, but t_max is invalid since it's less than max_dur.
         raise ValueError(
             f't_max {t_max} must be greater than or equal to max_duration {max_dur}. Otherwise, the LR schedule will '
             'not be defined for the entire training duration.',
         )
 
+    if t_max.unit != max_dur.unit:
+        # Units are not comparable, so we cannot check if t_max is valid. Log this and return.
+        log.debug(
+            f'Since max_duration {max_dur} with units {max_dur.unit} and t_max {t_max} with units {t_max.unit} are not '
+            'comparable, make sure that your LR schedule is defined at all points in the training duration.',
+        )
```
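
To make the new control flow easier to scan outside the diff, here is a simplified, self-contained sketch of the comparison logic; `max_duration_exceeds_t_max`, the stripped-down `TimeUnit` enum, and the plain `int` arguments are illustrative stand-ins, not Composer's actual API:

```python
from enum import Enum
from typing import Optional


class TimeUnit(Enum):
    BATCH = 'ba'
    EPOCH = 'ep'


def max_duration_exceeds_t_max(
    t_max_value: int,
    t_max_unit: TimeUnit,
    max_dur_value: int,
    max_dur_unit: TimeUnit,
    dataloader_len: Optional[int],
) -> Optional[bool]:
    """True if max_duration exceeds t_max, False if not, None if incomparable."""
    if t_max_unit == max_dur_unit:
        # Same units: compare values directly.
        return t_max_value < max_dur_value
    if t_max_unit == TimeUnit.BATCH and max_dur_unit == TimeUnit.EPOCH and dataloader_len is not None:
        # Convert epochs to batches via the dataloader length.
        return t_max_value < max_dur_value * dataloader_len
    if t_max_unit == TimeUnit.EPOCH and max_dur_unit == TimeUnit.BATCH and dataloader_len is not None:
        # Convert epochs to batches on the t_max side instead.
        return t_max_value * dataloader_len < max_dur_value
    # Units are not comparable; the caller can only log a hint.
    return None
```

In the PR itself, the True case raises the ValueError, the False cases return early, and only the incomparable-units case falls through to the log.debug hint, which is what stops the warning from firing on valid schedules.
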
Comment on lines +566 to +572

Contributor:
It seems like this log statement can be very repetitive, as it emanates from __call__. Should this be part of __init__ instead?

Contributor:
I agree with Antoine -- the problem described in the issue is less the warning level and more the frequency of logging.

Contributor:
I think the log was still an error, because the durations were equal through the dataloader length even though the units were different. The current refactoring solves this with the early returns.

Now that I give it a second thought, I'm not sure my proposal of handling it in __init__ makes sense, because max_duration can change if .fit() is called multiple times.


```diff
 def _raise_if_warmup_and_max_incompatible(t_warmup: Time[int], t_max: Time[int]):
     """Checks that t_warmup and t_max have the same units.
```
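
On the repetition concern raised in the thread above: since the check runs from the scheduler's __call__, any message it emits can repeat every step. One generic way to silence repeats is to deduplicate the message itself, sketched here with a hypothetical _warn_once helper that is not part of this PR (the PR instead lowers the level to log.debug and returns early):

```python
import logging
from functools import lru_cache

log = logging.getLogger(__name__)


@lru_cache(maxsize=None)
def _warn_once(message: str) -> None:
    # lru_cache memoizes on the message string, so each distinct message
    # is emitted at most once per process, however often __call__ repeats
    # the check.
    log.warning(message)
```

Note the message must be identical across calls for the cache to deduplicate it; interpolating a changing value such as the current batch count would defeat this.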

tests/optim/test_scheduler.py (1 change: 1 addition & 0 deletions)

```diff
@@ -440,6 +440,7 @@ def test_warmup_schedulers_fail_fast(
     max_duration_no_epoch = state.max_duration
     if max_duration_unit == 'ep':
         max_duration_no_epoch = Time.from_timestring(max_duration_unit_to_str['ba'])
+
     error_context = contextlib.nullcontext()
     if (
         hasattr(scheduler, 't_max') and
```
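
The touched test relies on a common pytest idiom visible in the context lines: error_context starts as contextlib.nullcontext() and is swapped for pytest.raises(...) when a failure is expected, so one parametrized test covers both outcomes. A minimal sketch of the pattern, with check_t_max as a hypothetical stand-in for the scheduler validation:

```python
import contextlib

import pytest


def check_t_max(t_max: int, max_duration: int) -> None:
    # Hypothetical stand-in for the scheduler validation under test.
    if t_max < max_duration:
        raise ValueError('t_max must be greater than or equal to max_duration')


@pytest.mark.parametrize('t_max,max_duration', [(10, 5), (5, 5), (3, 5)])
def test_check_t_max(t_max: int, max_duration: int) -> None:
    # Expect a ValueError only when t_max is too small; otherwise
    # nullcontext makes the `with` block a no-op wrapper.
    error_context = contextlib.nullcontext()
    if t_max < max_duration:
        error_context = pytest.raises(ValueError)
    with error_context:
        check_t_max(t_max, max_duration)
```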