From 53aa5636cf169572bc05329ca41663b50d34f214 Mon Sep 17 00:00:00 2001 From: Oliver Neumann Date: Thu, 30 Apr 2020 13:54:50 +0200 Subject: [PATCH 1/7] Fixed broken link in PR template (#1675) * Fixed broken link in PR template. * Updated CHANGELOG.md --- .github/PULL_REQUEST_TEMPLATE.md | 2 +- CHANGELOG.md | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index af80acf6a6390d..0bda363228b1c0 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -4,7 +4,7 @@ - [ ] Did you read the [contributor guideline](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/.github/CONTRIBUTING.md), Pull Request section? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? -- [ ] If you made a notable change (that affects users), did you update the [CHANGELOG](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/.github/CHANGELOG.md)? +- [ ] If you made a notable change (that affects users), did you update the [CHANGELOG](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/CHANGELOG.md)? diff --git a/CHANGELOG.md b/CHANGELOG.md index 9750c6566b769e..85edc73864efd4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed broken link in PR template ([#1675](https://github.com/PyTorchLightning/pytorch-lightning/pull/1675)) - Fixed ModelCheckpoint not None checking filepath ([1654](https://github.com/PyTorchLightning/pytorch-lightning/pull/1654)) From 8d564b5e38d1a1f820304a27f2d615d8bd4f401d Mon Sep 17 00:00:00 2001 From: Peter Yu <2057325+yukw777@users.noreply.github.com> Date: Thu, 30 Apr 2020 07:57:24 -0400 Subject: [PATCH 2/7] call on_load_checkpoint() when resuming from checkpoint (#1666) --- CHANGELOG.md | 1 + pytorch_lightning/trainer/training_io.py | 4 ++++ tests/trainer/test_trainer.py | 15 +++++++++++---- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 85edc73864efd4..10ec061f18b2bf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fixed broken link in PR template ([#1675](https://github.com/PyTorchLightning/pytorch-lightning/pull/1675)) - Fixed ModelCheckpoint not None checking filepath ([1654](https://github.com/PyTorchLightning/pytorch-lightning/pull/1654)) +- Trainer now calls `on_load_checkpoint()` when resuming from a checkpoint ([1666](https://github.com/PyTorchLightning/pytorch-lightning/pull/1666)) ## [0.7.5] - 2020-04-27 diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py index 82bc0829aa238e..393d6540398b7e 100644 --- a/pytorch_lightning/trainer/training_io.py +++ b/pytorch_lightning/trainer/training_io.py @@ -278,6 +278,10 @@ def restore(self, checkpoint_path: str, on_gpu: bool): # load the state_dict on the model automatically model.load_state_dict(checkpoint['state_dict']) + + # give model a chance to load something + model.on_load_checkpoint(checkpoint) + if on_gpu: model.cuda(self.root_gpu) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 18cc2586a0b9f6..cb650fd87e4c4d 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -309,8 +309,8 @@ def test_model_freeze_unfreeze(): model.unfreeze() -def test_resume_from_checkpoint_epoch_restored(tmpdir): - """Verify resuming from checkpoint runs the right number of epochs""" +def test_resume_from_checkpoint(tmpdir): + """Verify resuming from checkpoint (epoch, batch numbers and on_load_checkpoint())""" import types tutils.reset_seed() @@ -322,6 +322,7 @@ def _new_model(): model = LightningTestModel(hparams) model.num_epochs_seen = 0 model.num_batches_seen = 0 + model.num_on_load_checkpoint_called = 0 def increment_epoch(self): self.num_epochs_seen += 1 @@ -329,10 +330,14 @@ def increment_epoch(self): def increment_batch(self, _): self.num_batches_seen += 1 - # Bind the increment_epoch function on_epoch_end so that the - # model keeps track of the number of epochs it has seen. 
+ def increment_on_load_checkpoint(self, _): + self.num_on_load_checkpoint_called += 1 + + # Bind methods to keep track of epoch numbers, batch numbers it has seen + # as well as number of times it has called on_load_checkpoint() model.on_epoch_end = types.MethodType(increment_epoch, model) model.on_batch_start = types.MethodType(increment_batch, model) + model.on_load_checkpoint = types.MethodType(increment_on_load_checkpoint, model) return model model = _new_model() @@ -356,6 +361,7 @@ def increment_batch(self, _): assert model.num_epochs_seen == 2 assert model.num_batches_seen == training_batches * 2 + assert model.num_on_load_checkpoint_called == 0 # Other checkpoints can be uncommented if/when resuming mid-epoch is supported checkpoints = sorted(glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, '*.ckpt'))) @@ -369,6 +375,7 @@ def increment_batch(self, _): new_trainer = Trainer(**trainer_options, resume_from_checkpoint=check) new_trainer.fit(next_model) assert state['global_step'] + next_model.num_batches_seen == training_batches * trainer_options['max_epochs'] + assert next_model.num_on_load_checkpoint_called == 1 def _init_steps_model(): From f9c9e39ab87e393c157a30aa659be71eef11190e Mon Sep 17 00:00:00 2001 From: Jacob Zhong Date: Thu, 30 Apr 2020 07:58:03 -0400 Subject: [PATCH 3/7] Add log output for slurm (#1657) * add log output for slurm * change log levels * formatting Co-authored-by: Jirka Borovec --- pytorch_lightning/core/lightning.py | 2 ++ pytorch_lightning/trainer/distrib_data_parallel.py | 6 +++++- pytorch_lightning/trainer/training_io.py | 2 +- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index a1f3eb4e9252c3..26016613c63695 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -930,10 +930,12 @@ def init_ddp_connection( if 'MASTER_ADDR' not in os.environ: log.warning("MASTER_ADDR environment variable is not defined. Set as localhost") os.environ['MASTER_ADDR'] = '127.0.0.1' + log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}") if 'MASTER_PORT' not in os.environ: log.warning("MASTER_PORT environment variable is not defined. 
Set as 12910") os.environ['MASTER_PORT'] = '12910' + log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}") if 'WORLD_SIZE' in os.environ and os.environ['WORLD_SIZE'] != world_size: log.warning("WORLD_SIZE environment variable is not equal to the computed " diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py index 659aa7a072f9a7..56c7bae8ec6a71 100644 --- a/pytorch_lightning/trainer/distrib_data_parallel.py +++ b/pytorch_lightning/trainer/distrib_data_parallel.py @@ -277,6 +277,10 @@ def configure_slurm_ddp(self, num_gpu_nodes): except Exception as e: pass + # notify user the that slurm is managing tasks + if self.is_slurm_managing_tasks: + log.info('Multi-processing is handled by Slurm.') + def set_nvidia_flags(self, is_slurm_managing_tasks, data_parallel_device_ids): if data_parallel_device_ids is None: return @@ -293,7 +297,7 @@ def set_nvidia_flags(self, is_slurm_managing_tasks, data_parallel_device_ids): gpu_str = ','.join([str(x) for x in data_parallel_device_ids]) os.environ["CUDA_VISIBLE_DEVICES"] = gpu_str - log.info(f'CUDA_VISIBLE_DEVICES: [{os.environ["CUDA_VISIBLE_DEVICES"]}]') + log.debug(f'CUDA_VISIBLE_DEVICES: [{os.environ["CUDA_VISIBLE_DEVICES"]}]') def ddp_train(self, process_idx, model): """ diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py index 393d6540398b7e..e49329538fcf52 100644 --- a/pytorch_lightning/trainer/training_io.py +++ b/pytorch_lightning/trainer/training_io.py @@ -215,7 +215,7 @@ def sig_handler(self, signum, frame): # pragma: no-cover if result == 0: log.info(f'requeued exp {job_id}') else: - log.info('requeue failed...') + log.warning('requeue failed...') # close experiment to avoid issues self.logger.close() From 2ec8d61e94722f7ecc97e1add72f4ac693d2f612 Mon Sep 17 00:00:00 2001 From: weipengOO98 <63845580+weipengOO98@users.noreply.github.com> Date: Thu, 30 Apr 2020 19:58:42 +0800 Subject: [PATCH 4/7] Update new-project.rst (#1655) fix a typo --- docs/source/new-project.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/new-project.rst b/docs/source/new-project.rst index 7d81ba44a352f7..e3f3a892d983fa 100644 --- a/docs/source/new-project.rst +++ b/docs/source/new-project.rst @@ -100,7 +100,7 @@ To also add a validation loop add the following functions def validation_epoch_end(self, outputs): avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() tensorboard_logs = {'val_loss': avg_loss} - return {'val_loss': avg_loss, 'log': tensorboard_logs + return {'val_loss': avg_loss, 'log': tensorboard_logs} def val_dataloader(self): # TODO: do a real train/val split From d40425d2574c5698eed350e340d9ece779a68ac2 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Thu, 30 Apr 2020 08:04:18 -0400 Subject: [PATCH 5/7] added warning to crash (#1625) * added warning to crash * formatting Co-authored-by: J. 
Borovec --- pytorch_lightning/core/lightning.py | 6 +++--- pytorch_lightning/trainer/training_io.py | 8 ++++++-- setup.cfg | 1 + 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 26016613c63695..fc88fb8c786871 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1162,9 +1162,9 @@ def optimizer_step(self, current_epoch, batch_idx, optimizer, # native amp + lbfgs is a no go right now if self.trainer.use_amp and self.trainer.use_native_amp: - m = 'native PyTorch amp and lbfgs are not compatible. To request, please file' \ - 'a Github issue in PyTorch and tag @mcarilli' - raise MisconfigurationException(m) + raise MisconfigurationException( + 'native PyTorch amp and lbfgs are not compatible.' + ' To request, please file a Github issue in PyTorch and tag @mcarilli') optimizer.step(second_order_closure) else: if self.trainer.use_amp and self.trainer.use_native_amp: diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py index e49329538fcf52..78d24fad0a18f2 100644 --- a/pytorch_lightning/trainer/training_io.py +++ b/pytorch_lightning/trainer/training_io.py @@ -251,9 +251,11 @@ def save_checkpoint(self, filepath): # do the actual save try: self._atomic_save(checkpoint, filepath) - except AttributeError: + except AttributeError as e: if 'hparams' in checkpoint: del checkpoint['hparams'] + rank_zero_warn('warning, `hparams` dropped from checkpoint.' + f' An attribute is not picklable {e}') self._atomic_save(checkpoint, filepath) @@ -434,9 +436,11 @@ def hpc_save(self, folderpath: str, logger): # TODO: fix for anything with multiprocess DP, DDP, DDP2 try: self._atomic_save(checkpoint, filepath) - except AttributeError: + except AttributeError as e: if 'hparams' in checkpoint: del checkpoint['hparams'] + rank_zero_warn('warning, `hparams` dropped from checkpoint.' 
+ f' An attribute is not picklable {e}') self._atomic_save(checkpoint, filepath) diff --git a/setup.cfg b/setup.cfg index 2f1b55c1894b4e..aab7a580c77b91 100644 --- a/setup.cfg +++ b/setup.cfg @@ -18,6 +18,7 @@ exclude_lines = pragma: no-cover warnings pass + rank_zero_warn [flake8] # TODO: this should be 88 or 100 according PEP8 From 3eac6cfd4fbbc4d13f4e93f6d90f8ee5302c421e Mon Sep 17 00:00:00 2001 From: Nathan Breitsch Date: Thu, 30 Apr 2020 08:04:50 -0400 Subject: [PATCH 6/7] Don't convert namedtuple to tuple (#1589) * Don't convert namedtuple to tuple * Test namedtuples sent to device correctly --- pytorch_lightning/trainer/distrib_parts.py | 13 +++++++++---- tests/models/test_cpu.py | 8 ++++++++ 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index 73efaf67c486bb..db4e132c0b4457 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -461,10 +461,15 @@ def __transfer_data_to_device(self, batch, device, gpu_id=None): # when tuple if isinstance(batch, tuple): - batch = list(batch) - for i, x in enumerate(batch): - batch[i] = self.__transfer_data_to_device(x, device, gpu_id) - return tuple(batch) + # when namedtuple + if hasattr(batch, '_fields'): + elem_type = type(batch) + return elem_type(*(self.__transfer_data_to_device(x, device, gpu_id) for x in batch)) + else: + batch = list(batch) + for i, x in enumerate(batch): + batch[i] = self.__transfer_data_to_device(x, device, gpu_id) + return tuple(batch) # when dict if isinstance(batch, dict): diff --git a/tests/models/test_cpu.py b/tests/models/test_cpu.py index 612286404041b6..eb3b28769e2063 100644 --- a/tests/models/test_cpu.py +++ b/tests/models/test_cpu.py @@ -1,3 +1,4 @@ +from collections import namedtuple import platform import pytest @@ -221,6 +222,13 @@ def test_single_gpu_batch_parse(): assert batch[1][0]['b'].device.index == 0 assert batch[1][0]['b'].type() == 'torch.cuda.FloatTensor' + # namedtuple of tensor + BatchType = namedtuple('BatchType', ['a', 'b']) + batch = [BatchType(a=torch.rand(2, 3), b=torch.rand(2, 3)) for _ in range(2)] + batch = trainer.transfer_batch_to_gpu(batch, 0) + assert batch[0].a.device.index == 0 + assert batch[0].a.type() == 'torch.cuda.FloatTensor' + def test_simple_cpu(tmpdir): """Verify continue training session on CPU.""" From 142bc0230e228cd2e851481e5a07069e7d198655 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Thu, 30 Apr 2020 14:06:41 +0200 Subject: [PATCH 7/7] Learning rate log callback (#1498) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * base implementation * docs + implementation * fix styling * add lr string * renaming * CHANGELOG.md * add tests * Apply suggestions from code review Co-Authored-By: Adrian Wälchli * Apply suggestions from code review * Update pytorch_lightning/callbacks/lr_logger.py * Update pytorch_lightning/callbacks/lr_logger.py * add test for naming * base implementation * docs + implementation * fix styling * add lr string * renaming * CHANGELOG.md * add tests * Apply suggestions from code review Co-Authored-By: Adrian Wälchli * Apply suggestions from code review * Update pytorch_lightning/callbacks/lr_logger.py * Update pytorch_lightning/callbacks/lr_logger.py * add test for naming * Update pytorch_lightning/callbacks/lr_logger.py Co-Authored-By: Adrian Wälchli * suggestions from code review * fix styling * rebase * fix tests Co-authored-by: Nicki Skafte 
Co-authored-by: Jirka Borovec Co-authored-by: Adrian Wälchli --- CHANGELOG.md | 2 + docs/source/callbacks.rst | 8 ++ pytorch_lightning/callbacks/__init__.py | 2 + pytorch_lightning/callbacks/lr_logger.py | 118 +++++++++++++++++++++++ tests/callbacks/test_callbacks.py | 59 +++++++++++- 5 files changed, 188 insertions(+), 1 deletion(-) create mode 100755 pytorch_lightning/callbacks/lr_logger.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 10ec061f18b2bf..f67e85a452ff9f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added callback for logging learning rates ([#1498](https://github.com/PyTorchLightning/pytorch-lightning/pull/1498)) + ### Changed ### Deprecated diff --git a/docs/source/callbacks.rst b/docs/source/callbacks.rst index 10323472facd8f..a2969820b2eebc 100644 --- a/docs/source/callbacks.rst +++ b/docs/source/callbacks.rst @@ -84,3 +84,11 @@ We successfully extended functionality without polluting our super clean .. automodule:: pytorch_lightning.callbacks.progress :noindex: :exclude-members: + +--------- + +.. automodule:: pytorch_lightning.callbacks.lr_logger + :noindex: + :exclude-members: + _extract_lr, + _find_names \ No newline at end of file diff --git a/pytorch_lightning/callbacks/__init__.py b/pytorch_lightning/callbacks/__init__.py index c232060ca4ecb4..7e8e0ce5bcfef3 100644 --- a/pytorch_lightning/callbacks/__init__.py +++ b/pytorch_lightning/callbacks/__init__.py @@ -2,6 +2,7 @@ from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint +from pytorch_lightning.callbacks.lr_logger import LearningRateLogger from pytorch_lightning.callbacks.progress import ProgressBarBase, ProgressBar __all__ = [ @@ -9,6 +10,7 @@ 'EarlyStopping', 'ModelCheckpoint', 'GradientAccumulationScheduler', + 'LearningRateLogger', 'ProgressBarBase', 'ProgressBar', ] diff --git a/pytorch_lightning/callbacks/lr_logger.py b/pytorch_lightning/callbacks/lr_logger.py new file mode 100755 index 00000000000000..6ad68905bc3417 --- /dev/null +++ b/pytorch_lightning/callbacks/lr_logger.py @@ -0,0 +1,118 @@ +r""" + +Logging of learning rates +========================= + +Log learning rate for lr schedulers during training + +""" + +from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.utilities.exceptions import MisconfigurationException + + +class LearningRateLogger(Callback): + r""" + Automatically logs learning rate for learning rate schedulers during training. + + Example:: + + >>> from pytorch_lightning import Trainer + >>> from pytorch_lightning.callbacks import LearningRateLogger + >>> lr_logger = LearningRateLogger() + >>> trainer = Trainer(callbacks=[lr_logger]) + + Logging names are automatically determined based on optimizer class name. + In case of multiple optimizers of same type, they will be named `Adam`, + `Adam-1` etc. If a optimizer has multiple parameter groups they will + be named `Adam/pg1`, `Adam/pg2` etc. To control naming, pass in a + `name` keyword in the construction of the learning rate schdulers + + Example:: + + def configure_optimizer(self): + optimizer = torch.optim.Adam(...) + lr_scheduler = {'scheduler': torch.optim.lr_schedulers.LambdaLR(optimizer, ...) 
+ 'name': 'my_logging_name'} + return [optimizer], [lr_scheduler] + """ + def __init__(self): + self.lrs = None + self.lr_sch_names = [] + + def on_train_start(self, trainer, pl_module): + """ Called before training, determines unique names for all lr + schedulers in the case of multiple of the same type or in + the case of multiple parameter groups + """ + if trainer.lr_schedulers == []: + raise MisconfigurationException( + 'Cannot use LearningRateLogger callback with models that have no' + ' learning rate schedulers. Please see documentation for' + ' `configure_optimizers` method.') + + if not trainer.logger: + raise MisconfigurationException( + 'Cannot use LearningRateLogger callback with Trainer that has no logger.') + + # Find names for schedulers + names = self._find_names(trainer.lr_schedulers) + + # Initialize for storing values + self.lrs = dict.fromkeys(names, []) + + def on_batch_start(self, trainer, pl_module): + latest_stat = self._extract_lr(trainer, 'step') + if trainer.logger and latest_stat: + trainer.logger.log_metrics(latest_stat, step=trainer.global_step) + + def on_epoch_start(self, trainer, pl_module): + latest_stat = self._extract_lr(trainer, 'epoch') + if trainer.logger and latest_stat: + trainer.logger.log_metrics(latest_stat, step=trainer.global_step) + + def _extract_lr(self, trainer, interval): + """ Extracts learning rates for lr schedulers and saves information + into dict structure. """ + latest_stat = {} + for name, scheduler in zip(self.lr_sch_names, trainer.lr_schedulers): + if scheduler['interval'] == interval: + param_groups = scheduler['scheduler'].optimizer.param_groups + if len(param_groups) != 1: + for i, pg in enumerate(param_groups): + lr, key = pg['lr'], f'{name}/{i + 1}' + self.lrs[key].append(lr) + latest_stat[key] = lr + else: + self.lrs[name].append(param_groups[0]['lr']) + latest_stat[name] = param_groups[0]['lr'] + return latest_stat + + def _find_names(self, lr_schedulers): + # Create uniqe names in the case we have multiple of the same learning + # rate schduler + multiple parameter groups + names = [] + for scheduler in lr_schedulers: + sch = scheduler['scheduler'] + if 'name' in scheduler: + name = scheduler['name'] + else: + opt_name = 'lr-' + sch.optimizer.__class__.__name__ + i, name = 1, opt_name + # Multiple schduler of the same type + while True: + if name not in names: + break + i, name = i + 1, f'{opt_name}-{i}' + + # Multiple param groups for the same schduler + param_groups = sch.optimizer.param_groups + if len(param_groups) != 1: + for i, pg in enumerate(param_groups): + temp = name + '/pg' + str(i + 1) + names.append(temp) + else: + names.append(name) + + self.lr_sch_names.append(name) + return names diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index 9dba21eab07d87..a082c5ec6f1a65 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -2,11 +2,12 @@ import tests.base.utils as tutils from pytorch_lightning import Callback from pytorch_lightning import Trainer, LightningModule -from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint +from pytorch_lightning.callbacks import EarlyStopping, LearningRateLogger, ModelCheckpoint from tests.base import ( LightTrainDataloader, LightTestMixin, LightValidationMixin, + LightTestOptimizersWithMixedSchedulingMixin, TestModelBase ) @@ -273,3 +274,59 @@ class CurrentTestModel(LightTrainDataloader, TestModelBase): # These should be different if the dirpath has be overridden assert trainer.ckpt_path != 
trainer.default_root_dir + + +def test_lr_logger_single_lr(tmpdir): + """ Test that learning rates are extracted and logged for single lr scheduler""" + tutils.reset_seed() + + class CurrentTestModel(LightTrainDataloader, TestModelBase): + pass + + hparams = tutils.get_default_hparams() + model = CurrentTestModel(hparams) + + lr_logger = LearningRateLogger() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=5, + val_percent_check=0.1, + train_percent_check=0.5, + callbacks=[lr_logger] + ) + results = trainer.fit(model) + + assert lr_logger.lrs, 'No learning rates logged' + assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \ + 'Number of learning rates logged does not match number of lr schedulers' + assert all([k in ['lr-Adam'] for k in lr_logger.lrs.keys()]), \ + 'Names of learning rates not set correctly' + + +def test_lr_logger_multi_lrs(tmpdir): + """ Test that learning rates are extracted and logged for multi lr schedulers """ + tutils.reset_seed() + + class CurrentTestModel(LightTestOptimizersWithMixedSchedulingMixin, + LightTrainDataloader, + TestModelBase): + pass + + hparams = tutils.get_default_hparams() + model = CurrentTestModel(hparams) + + lr_logger = LearningRateLogger() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + val_percent_check=0.1, + train_percent_check=0.5, + callbacks=[lr_logger] + ) + results = trainer.fit(model) + + assert lr_logger.lrs, 'No learning rates logged' + assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \ + 'Number of learning rates logged does not match number of lr schedulers' + assert all([k in ['lr-Adam', 'lr-Adam-1'] for k in lr_logger.lrs.keys()]), \ + 'Names of learning rates not set correctly'
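
Example usage of the new `LearningRateLogger` callback introduced in PATCH 7/7. This is a minimal, self-contained sketch and not part of the patches themselves: it assumes pytorch-lightning 0.7.x with the default logger, and the toy model, dataset and hyperparameters below are illustrative only. The pieces taken from the patch are `LearningRateLogger`, passing it via `Trainer(callbacks=[...])`, the optional `'name'` key in the scheduler dict, and the `lrs` attribute populated during training::

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    import pytorch_lightning as pl
    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import LearningRateLogger


    class ToyModel(pl.LightningModule):
        """Illustrative module (not from the patches) with one scheduler to log."""

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(4, 1)

        def forward(self, x):
            return self.layer(x)

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = torch.nn.functional.mse_loss(self(x), y)
            return {'loss': loss}

        def train_dataloader(self):
            x, y = torch.randn(64, 4), torch.randn(64, 1)
            return DataLoader(TensorDataset(x, y), batch_size=16)

        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
            # the optional 'name' key controls the logging name picked by
            # LearningRateLogger._find_names(); note the PyTorch module is
            # torch.optim.lr_scheduler (singular)
            scheduler = {
                'scheduler': torch.optim.lr_scheduler.LambdaLR(
                    optimizer, lambda epoch: 0.95 ** epoch),
                'name': 'my_logging_name',
            }
            return [optimizer], [scheduler]


    if __name__ == '__main__':
        lr_logger = LearningRateLogger()
        trainer = Trainer(max_epochs=3, callbacks=[lr_logger])
        trainer.fit(ToyModel())
        # learning rates collected per scheduler name,
        # e.g. {'my_logging_name': [0.001, 0.00095, ...]}
        print(lr_logger.lrs)

With the default `'epoch'` interval the callback records one value per epoch at `on_epoch_start`; schedulers configured with `'interval': 'step'` are instead logged at `on_batch_start`, as implemented in `_extract_lr` above.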