From 6c56a26821d7539d8e0e368b7e911e187ed4afed Mon Sep 17 00:00:00 2001 From: "J. Borovec" Date: Tue, 18 Feb 2020 15:45:24 +0100 Subject: [PATCH 1/6] refactor, cut-out train epoch --- pytorch_lightning/trainer/training_loop.py | 173 ++++++++++----------- 1 file changed, 85 insertions(+), 88 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index e59183aa6b94e..324f1d2bc45a9 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -169,16 +169,9 @@ def training_step(self, batch, batch_idx): try: import torch_xla.core.xla_model as xm - - XLA_AVAILABLE = True -except ImportError: - XLA_AVAILABLE = False - -try: import torch_xla.distributed.parallel_loader as xla_pl XLA_AVAILABLE = True - except ImportError: XLA_AVAILABLE = False @@ -304,93 +297,97 @@ def process_output(self, output, train): # this is just empty shell for code from other class pass + def _training_epoch(self, epoch: int): + """Run single training epoch with `epoch` index.""" + # set seed for distributed sampler (enables shuffling for each epoch) + if (self.use_ddp or self.use_tpu) \ + and hasattr(self.get_train_dataloader().sampler, 'set_epoch'): + self.get_train_dataloader().sampler.set_epoch(epoch) + + # get model + model = self.get_model() + + # update training progress in trainer and model + model.current_epoch = epoch + self.current_epoch = epoch + + total_val_batches = 0 + is_val_epoch = False + if not self.disable_validation: + # val can be checked multiple times in epoch + is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0 + val_checks_per_epoch = self.num_training_batches // self.val_check_batch + val_checks_per_epoch = val_checks_per_epoch if is_val_epoch else 0 + total_val_batches = self.num_val_batches * val_checks_per_epoch + + # total batches includes multiple val checks + self.total_batches = self.num_training_batches + total_val_batches + self.batch_loss_value = 0 # accumulated grads + + if self.fast_dev_run: + # limit the number of batches to 2 (1 train and 1 val) in fast_dev_run + num_iterations = 2 + elif self.is_iterable_train_dataloader: + # for iterable train loader, the progress bar never ends + num_iterations = None + else: + num_iterations = self.total_batches + + # reset progress bar + # .reset() doesn't work on disabled progress bar so we should check + if not self.main_progress_bar.disable: + self.main_progress_bar.reset(num_iterations) + desc = f"Epoch {epoch + 1}" if not self.is_iterable_train_dataloader else "" + self.main_progress_bar.set_description(desc) + + # changing gradient according accumulation_scheduler + self.accumulation_scheduler.on_epoch_begin() + + # ----------------- + # RUN TRAINING EPOCH + # ----------------- + self.run_training_epoch() + + # update LR schedulers + if self.lr_schedulers: + for lr_scheduler in self.lr_schedulers: + lr_scheduler.step() + if self.reduce_lr_on_plateau_scheduler is not None: + val_loss = self.callback_metrics.get('val_loss') + if val_loss is None: + avail_metrics = ','.join(list(self.callback_metrics.keys())) + m = "ReduceLROnPlateau conditioned on metric val_loss" \ + f" which is not available. 
Available metrics are: {avail_metrics}" + raise MisconfigurationException(m) + self.reduce_lr_on_plateau_scheduler.step(val_loss, epoch=self.current_epoch) + + if self.max_steps and self.max_steps == self.global_step: + self.main_progress_bar.close() + model.on_train_end() + return + + # early stopping + met_min_epochs = epoch >= self.min_epochs - 1 + met_min_steps = self.global_step >= self.min_steps if self.min_steps else True + + if (self.enable_early_stop and not self.disable_validation and is_val_epoch and + ((met_min_epochs and met_min_steps) or self.fast_dev_run)): + should_stop = self.early_stop_callback.on_epoch_end() + # stop training + stop = should_stop and met_min_epochs + if stop: + self.main_progress_bar.close() + with self.profiler.profile('on_train_end'): + model.on_train_end() + return + def train(self): warnings.warn('Displayed epoch numbers in the progress bar start from "1" until v0.6.x,' ' but will start from "0" in v0.8.0.', DeprecationWarning) model = self.get_model() # run all epochs for epoch in range(self.current_epoch, self.max_epochs): - # set seed for distributed sampler (enables shuffling for each epoch) - if (self.use_ddp or self.use_tpu) \ - and hasattr(self.get_train_dataloader().sampler, 'set_epoch'): - self.get_train_dataloader().sampler.set_epoch(epoch) - - # get model - model = self.get_model() - - # update training progress in trainer and model - model.current_epoch = epoch - self.current_epoch = epoch - - total_val_batches = 0 - is_val_epoch = False - if not self.disable_validation: - # val can be checked multiple times in epoch - is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0 - val_checks_per_epoch = self.num_training_batches // self.val_check_batch - val_checks_per_epoch = val_checks_per_epoch if is_val_epoch else 0 - total_val_batches = self.num_val_batches * val_checks_per_epoch - - # total batches includes multiple val checks - self.total_batches = self.num_training_batches + total_val_batches - self.batch_loss_value = 0 # accumulated grads - - if self.fast_dev_run: - # limit the number of batches to 2 (1 train and 1 val) in fast_dev_run - num_iterations = 2 - elif self.is_iterable_train_dataloader: - # for iterable train loader, the progress bar never ends - num_iterations = None - else: - num_iterations = self.total_batches - - # reset progress bar - # .reset() doesn't work on disabled progress bar so we should check - if not self.main_progress_bar.disable: - self.main_progress_bar.reset(num_iterations) - desc = f'Epoch {epoch + 1}' if not self.is_iterable_train_dataloader else '' - self.main_progress_bar.set_description(desc) - - # changing gradient according accumulation_scheduler - self.accumulation_scheduler.on_epoch_begin() - - # ----------------- - # RUN TNG EPOCH - # ----------------- - self.run_training_epoch() - - # update LR schedulers - if self.lr_schedulers is not None: - for lr_scheduler in self.lr_schedulers: - lr_scheduler.step() - if self.reduce_lr_on_plateau_scheduler is not None: - val_loss = self.callback_metrics.get('val_loss') - if val_loss is None: - avail_metrics = ','.join(list(self.callback_metrics.keys())) - m = f'ReduceLROnPlateau conditioned on metric val_loss ' \ - f'which is not available. 
Available metrics are: {avail_metrics}' - raise MisconfigurationException(m) - self.reduce_lr_on_plateau_scheduler.step(val_loss) - - if self.max_steps and self.max_steps == self.global_step: - self.main_progress_bar.close() - model.on_train_end() - return - - # early stopping - met_min_epochs = epoch >= self.min_epochs - 1 - met_min_steps = self.global_step >= self.min_steps if self.min_steps else True - - if (self.enable_early_stop and not self.disable_validation and is_val_epoch and - ((met_min_epochs and met_min_steps) or self.fast_dev_run)): - should_stop = self.early_stop_callback.on_epoch_end() - # stop training - stop = should_stop and met_min_epochs - if stop: - self.main_progress_bar.close() - with self.profiler.profile('on_train_end'): - model.on_train_end() - return + self._training_epoch(epoch) self.main_progress_bar.close() From 05178eddc00a82b54274a7dff19658df13ea86a6 Mon Sep 17 00:00:00 2001 From: "J. Borovec" Date: Tue, 18 Feb 2020 16:44:37 +0100 Subject: [PATCH 2/6] add timeout --- .github/workflows/ci-testing.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index 4e6c2685325b5..ec5d7b05c56e0 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -3,6 +3,9 @@ name: CI testing on: [push, pull_request] jobs: + # https://stackoverflow.com/a/59076067/4521646 + timeout-minutes: 20 + build: runs-on: ${{ matrix.os }} From 187f382003d988d15d7ee82e45f3dc14c3246f84 Mon Sep 17 00:00:00 2001 From: "J. Borovec" Date: Tue, 18 Feb 2020 17:10:50 +0100 Subject: [PATCH 3/6] add timeout --- .github/workflows/ci-testing.yml | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index ec5d7b05c56e0..26a3e99ece84d 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -3,9 +3,6 @@ name: CI testing on: [push, pull_request] jobs: - # https://stackoverflow.com/a/59076067/4521646 - timeout-minutes: 20 - build: runs-on: ${{ matrix.os }} @@ -17,6 +14,8 @@ jobs: python-version: [3.6, 3.7] requires: ['minimal', 'latest'] + # https://stackoverflow.com/a/59076067/4521646 + timeout-minutes: 20 steps: - uses: actions/checkout@v1 - name: Set up Python ${{ matrix.python-version }} From 87d1d858d7f26592fba2974d8eaf742482ff0038 Mon Sep 17 00:00:00 2001 From: "J. Borovec" Date: Tue, 18 Feb 2020 17:13:29 +0100 Subject: [PATCH 4/6] multiple --- pytorch_lightning/trainer/training_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 324f1d2bc45a9..4f5892582c4ee 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -359,7 +359,7 @@ def _training_epoch(self, epoch: int): m = "ReduceLROnPlateau conditioned on metric val_loss" \ f" which is not available. Available metrics are: {avail_metrics}" raise MisconfigurationException(m) - self.reduce_lr_on_plateau_scheduler.step(val_loss, epoch=self.current_epoch) + self.reduce_lr_on_plateau_scheduler.step(val_loss) if self.max_steps and self.max_steps == self.global_step: self.main_progress_bar.close() From 1b31c3f342cbf23b7686acd8134a35094157934a Mon Sep 17 00:00:00 2001 From: "J. 
Borovec" Date: Tue, 18 Feb 2020 23:02:24 +0100 Subject: [PATCH 5/6] try test --- pytorch_lightning/trainer/training_loop.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 4f5892582c4ee..224a13d037eb6 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -162,18 +162,18 @@ def training_step(self, batch, batch_idx): try: from apex import amp - - APEX_AVAILABLE = True except ImportError: APEX_AVAILABLE = False +else: + APEX_AVAILABLE = True try: import torch_xla.core.xla_model as xm import torch_xla.distributed.parallel_loader as xla_pl - - XLA_AVAILABLE = True except ImportError: XLA_AVAILABLE = False +else: + XLA_AVAILABLE = True class TrainerTrainLoopMixin(ABC): @@ -349,7 +349,7 @@ def _training_epoch(self, epoch: int): self.run_training_epoch() # update LR schedulers - if self.lr_schedulers: + if self.lr_schedulers is not None: for lr_scheduler in self.lr_schedulers: lr_scheduler.step() if self.reduce_lr_on_plateau_scheduler is not None: From 59e707d5b7c28325fbe0ca548aa87e4ece862f62 Mon Sep 17 00:00:00 2001 From: "J. Borovec" Date: Tue, 18 Feb 2020 23:32:19 +0100 Subject: [PATCH 6/6] try test --- pytorch_lightning/trainer/training_loop.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 224a13d037eb6..a4b0ee9a8a93f 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -351,7 +351,10 @@ def _training_epoch(self, epoch: int): # update LR schedulers if self.lr_schedulers is not None: for lr_scheduler in self.lr_schedulers: + # while lr_scheduler.last_epoch < self.current_epoch: + # lr_scheduler.step() lr_scheduler.step() + if self.reduce_lr_on_plateau_scheduler is not None: val_loss = self.callback_metrics.get('val_loss') if val_loss is None: @@ -359,6 +362,8 @@ def _training_epoch(self, epoch: int): m = "ReduceLROnPlateau conditioned on metric val_loss" \ f" which is not available. Available metrics are: {avail_metrics}" raise MisconfigurationException(m) + # while self.reduce_lr_on_plateau_scheduler.last_epoch < self.current_epoch: + # self.reduce_lr_on_plateau_scheduler.step(val_loss) self.reduce_lr_on_plateau_scheduler.step(val_loss) if self.max_steps and self.max_steps == self.global_step: