-
Notifications
You must be signed in to change notification settings - Fork 3.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Changed LearningRateLogger to LearningRateMonitor (#3251)
* Change LearningRateLogger to LearningRateMonitor * file rename * docs * add LearningRateLogger with deprecation warning * deprecated LearningRateLogger * move deprecation check * chlog Co-authored-by: Jirka Borovec <jirka@pytorchlightning.ai>
- Loading branch information
Showing
7 changed files
with
224 additions
and
184 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,18 +1,20 @@ | ||
from pytorch_lightning.callbacks.base import Callback | ||
from pytorch_lightning.callbacks.early_stopping import EarlyStopping | ||
from pytorch_lightning.callbacks.gpu_stats_monitor import GPUStatsMonitor | ||
from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler | ||
from pytorch_lightning.callbacks.lr_logger import LearningRateLogger | ||
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor | ||
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint | ||
from pytorch_lightning.callbacks.progress import ProgressBarBase, ProgressBar | ||
from pytorch_lightning.callbacks.gpu_stats_monitor import GPUStatsMonitor | ||
from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase | ||
|
||
__all__ = [ | ||
'Callback', | ||
'EarlyStopping', | ||
'ModelCheckpoint', | ||
'GPUStatsMonitor', | ||
'GradientAccumulationScheduler', | ||
'LearningRateLogger', | ||
'ProgressBarBase', | ||
'LearningRateMonitor', | ||
'ModelCheckpoint', | ||
'ProgressBar', | ||
'GPUStatsMonitor' | ||
'ProgressBarBase', | ||
] |
159 changes: 6 additions & 153 deletions
159
pytorch_lightning/callbacks/lr_logger.py
100755 → 100644
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,156 +1,9 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
r""" | ||
Learning Rate Logger | ||
==================== | ||
Log learning rate for lr schedulers during training | ||
""" | ||
|
||
from typing import Optional | ||
|
||
from pytorch_lightning.callbacks.base import Callback | ||
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor | ||
from pytorch_lightning.utilities import rank_zero_warn | ||
from pytorch_lightning.utilities.exceptions import MisconfigurationException | ||
|
||
|
||
class LearningRateLogger(Callback): | ||
r""" | ||
Automatically logs learning rate for learning rate schedulers during training. | ||
Args: | ||
logging_interval: set to `epoch` or `step` to log `lr` of all optimizers | ||
at the same interval, set to `None` to log at individual interval | ||
according to the `interval` key of each scheduler. Defaults to ``None``. | ||
Example:: | ||
>>> from pytorch_lightning import Trainer | ||
>>> from pytorch_lightning.callbacks import LearningRateLogger | ||
>>> lr_logger = LearningRateLogger(logging_interval='step') | ||
>>> trainer = Trainer(callbacks=[lr_logger]) | ||
Logging names are automatically determined based on optimizer class name. | ||
In case of multiple optimizers of same type, they will be named `Adam`, | ||
`Adam-1` etc. If a optimizer has multiple parameter groups they will | ||
be named `Adam/pg1`, `Adam/pg2` etc. To control naming, pass in a | ||
`name` keyword in the construction of the learning rate schdulers | ||
Example:: | ||
def configure_optimizer(self): | ||
optimizer = torch.optim.Adam(...) | ||
lr_scheduler = {'scheduler': torch.optim.lr_schedulers.LambdaLR(optimizer, ...) | ||
'name': 'my_logging_name'} | ||
return [optimizer], [lr_scheduler] | ||
""" | ||
def __init__(self, logging_interval: Optional[str] = None): | ||
if logging_interval not in (None, 'step', 'epoch'): | ||
raise MisconfigurationException( | ||
'logging_interval should be `step` or `epoch` or `None`.' | ||
) | ||
|
||
self.logging_interval = logging_interval | ||
self.lrs = None | ||
self.lr_sch_names = [] | ||
|
||
def on_train_start(self, trainer, pl_module): | ||
""" Called before training, determines unique names for all lr | ||
schedulers in the case of multiple of the same type or in | ||
the case of multiple parameter groups | ||
""" | ||
if not trainer.logger: | ||
raise MisconfigurationException( | ||
'Cannot use LearningRateLogger callback with Trainer that has no logger.' | ||
) | ||
|
||
if not trainer.lr_schedulers: | ||
rank_zero_warn( | ||
'You are using LearningRateLogger callback with models that' | ||
' have no learning rate schedulers. Please see documentation' | ||
' for `configure_optimizers` method.', RuntimeWarning | ||
) | ||
|
||
# Find names for schedulers | ||
names = self._find_names(trainer.lr_schedulers) | ||
|
||
# Initialize for storing values | ||
self.lrs = {name: [] for name in names} | ||
|
||
def on_batch_start(self, trainer, pl_module): | ||
if self.logging_interval != 'epoch': | ||
interval = 'step' if self.logging_interval is None else 'any' | ||
latest_stat = self._extract_lr(trainer, interval) | ||
|
||
if trainer.logger is not None and latest_stat: | ||
trainer.logger.log_metrics(latest_stat, step=trainer.global_step) | ||
|
||
def on_epoch_start(self, trainer, pl_module): | ||
if self.logging_interval != 'step': | ||
interval = 'epoch' if self.logging_interval is None else 'any' | ||
latest_stat = self._extract_lr(trainer, interval) | ||
|
||
if trainer.logger is not None and latest_stat: | ||
trainer.logger.log_metrics(latest_stat, step=trainer.current_epoch) | ||
|
||
def _extract_lr(self, trainer, interval): | ||
latest_stat = {} | ||
|
||
for name, scheduler in zip(self.lr_sch_names, trainer.lr_schedulers): | ||
if scheduler['interval'] == interval or interval == 'any': | ||
param_groups = scheduler['scheduler'].optimizer.param_groups | ||
if len(param_groups) != 1: | ||
for i, pg in enumerate(param_groups): | ||
lr, key = pg['lr'], f'{name}/pg{i + 1}' | ||
self.lrs[key].append(lr) | ||
latest_stat[key] = lr | ||
else: | ||
self.lrs[name].append(param_groups[0]['lr']) | ||
latest_stat[name] = param_groups[0]['lr'] | ||
|
||
return latest_stat | ||
|
||
def _find_names(self, lr_schedulers): | ||
# Create uniqe names in the case we have multiple of the same learning | ||
# rate schduler + multiple parameter groups | ||
names = [] | ||
for scheduler in lr_schedulers: | ||
sch = scheduler['scheduler'] | ||
if 'name' in scheduler: | ||
name = scheduler['name'] | ||
else: | ||
opt_name = 'lr-' + sch.optimizer.__class__.__name__ | ||
i, name = 1, opt_name | ||
# Multiple schduler of the same type | ||
while True: | ||
if name not in names: | ||
break | ||
i, name = i + 1, f'{opt_name}-{i}' | ||
|
||
# Multiple param groups for the same schduler | ||
param_groups = sch.optimizer.param_groups | ||
|
||
if len(param_groups) != 1: | ||
for i, pg in enumerate(param_groups): | ||
temp = f'{name}/pg{i + 1}' | ||
names.append(temp) | ||
else: | ||
names.append(name) | ||
|
||
self.lr_sch_names.append(name) | ||
|
||
return names | ||
class LearningRateLogger(LearningRateMonitor): | ||
def __init__(self, *args, **kwargs): | ||
rank_zero_warn("`LearningRateLogger` is now `LearningRateMonitor`" | ||
" and this will be removed in v0.11.0", DeprecationWarning) | ||
super().__init__(*args, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
r""" | ||
Learning Rate Monitor | ||
===================== | ||
Monitor and logs learning rate for lr schedulers during training. | ||
""" | ||
|
||
from typing import Optional | ||
|
||
from pytorch_lightning.callbacks.base import Callback | ||
from pytorch_lightning.utilities import rank_zero_warn | ||
from pytorch_lightning.utilities.exceptions import MisconfigurationException | ||
|
||
|
||
class LearningRateMonitor(Callback): | ||
r""" | ||
Automatically monitor and logs learning rate for learning rate schedulers during training. | ||
Args: | ||
logging_interval: set to `epoch` or `step` to log `lr` of all optimizers | ||
at the same interval, set to `None` to log at individual interval | ||
according to the `interval` key of each scheduler. Defaults to ``None``. | ||
Example:: | ||
>>> from pytorch_lightning import Trainer | ||
>>> from pytorch_lightning.callbacks import LearningRateMonitor | ||
>>> lr_monitor = LearningRateMonitor(logging_interval='step') | ||
>>> trainer = Trainer(callbacks=[lr_monitor]) | ||
Logging names are automatically determined based on optimizer class name. | ||
In case of multiple optimizers of same type, they will be named `Adam`, | ||
`Adam-1` etc. If a optimizer has multiple parameter groups they will | ||
be named `Adam/pg1`, `Adam/pg2` etc. To control naming, pass in a | ||
`name` keyword in the construction of the learning rate schdulers | ||
Example:: | ||
def configure_optimizer(self): | ||
optimizer = torch.optim.Adam(...) | ||
lr_scheduler = {'scheduler': torch.optim.lr_schedulers.LambdaLR(optimizer, ...) | ||
'name': 'my_logging_name'} | ||
return [optimizer], [lr_scheduler] | ||
""" | ||
def __init__(self, logging_interval: Optional[str] = None): | ||
if logging_interval not in (None, 'step', 'epoch'): | ||
raise MisconfigurationException( | ||
'logging_interval should be `step` or `epoch` or `None`.' | ||
) | ||
|
||
self.logging_interval = logging_interval | ||
self.lrs = None | ||
self.lr_sch_names = [] | ||
|
||
def on_train_start(self, trainer, pl_module): | ||
""" | ||
Called before training, determines unique names for all lr | ||
schedulers in the case of multiple of the same type or in | ||
the case of multiple parameter groups | ||
""" | ||
if not trainer.logger: | ||
raise MisconfigurationException( | ||
'Cannot use LearningRateMonitor callback with Trainer that has no logger.' | ||
) | ||
|
||
if not trainer.lr_schedulers: | ||
rank_zero_warn( | ||
'You are using LearningRateMonitor callback with models that' | ||
' have no learning rate schedulers. Please see documentation' | ||
' for `configure_optimizers` method.', RuntimeWarning | ||
) | ||
|
||
# Find names for schedulers | ||
names = self._find_names(trainer.lr_schedulers) | ||
|
||
# Initialize for storing values | ||
self.lrs = {name: [] for name in names} | ||
|
||
def on_batch_start(self, trainer, pl_module): | ||
if self.logging_interval != 'epoch': | ||
interval = 'step' if self.logging_interval is None else 'any' | ||
latest_stat = self._extract_lr(trainer, interval) | ||
|
||
if trainer.logger is not None and latest_stat: | ||
trainer.logger.log_metrics(latest_stat, step=trainer.global_step) | ||
|
||
def on_epoch_start(self, trainer, pl_module): | ||
if self.logging_interval != 'step': | ||
interval = 'epoch' if self.logging_interval is None else 'any' | ||
latest_stat = self._extract_lr(trainer, interval) | ||
|
||
if trainer.logger is not None and latest_stat: | ||
trainer.logger.log_metrics(latest_stat, step=trainer.current_epoch) | ||
|
||
def _extract_lr(self, trainer, interval): | ||
latest_stat = {} | ||
|
||
for name, scheduler in zip(self.lr_sch_names, trainer.lr_schedulers): | ||
if scheduler['interval'] == interval or interval == 'any': | ||
param_groups = scheduler['scheduler'].optimizer.param_groups | ||
if len(param_groups) != 1: | ||
for i, pg in enumerate(param_groups): | ||
lr, key = pg['lr'], f'{name}/pg{i + 1}' | ||
self.lrs[key].append(lr) | ||
latest_stat[key] = lr | ||
else: | ||
self.lrs[name].append(param_groups[0]['lr']) | ||
latest_stat[name] = param_groups[0]['lr'] | ||
|
||
return latest_stat | ||
|
||
def _find_names(self, lr_schedulers): | ||
# Create uniqe names in the case we have multiple of the same learning | ||
# rate schduler + multiple parameter groups | ||
names = [] | ||
for scheduler in lr_schedulers: | ||
sch = scheduler['scheduler'] | ||
if 'name' in scheduler: | ||
name = scheduler['name'] | ||
else: | ||
opt_name = 'lr-' + sch.optimizer.__class__.__name__ | ||
i, name = 1, opt_name | ||
|
||
# Multiple schduler of the same type | ||
while True: | ||
if name not in names: | ||
break | ||
i, name = i + 1, f'{opt_name}-{i}' | ||
|
||
# Multiple param groups for the same schduler | ||
param_groups = sch.optimizer.param_groups | ||
|
||
if len(param_groups) != 1: | ||
for i, pg in enumerate(param_groups): | ||
temp = f'{name}/pg{i + 1}' | ||
names.append(temp) | ||
else: | ||
names.append(name) | ||
|
||
self.lr_sch_names.append(name) | ||
|
||
return names |
Oops, something went wrong.