cut-out train step [wip] #891

Closed · wants to merge 6 commits

Changes from all commits
2 changes: 2 additions & 0 deletions .github/workflows/ci-testing.yml
@@ -14,6 +14,8 @@ jobs:
python-version: [3.6, 3.7]
requires: ['minimal', 'latest']

# https://stackoverflow.com/a/59076067/4521646
timeout-minutes: 20
steps:
- uses: actions/checkout@v1
- name: Set up Python ${{ matrix.python-version }}
186 changes: 94 additions & 92 deletions pytorch_lightning/trainer/training_loop.py
@@ -162,25 +162,18 @@ def training_step(self, batch, batch_idx):

try:
from apex import amp

APEX_AVAILABLE = True
except ImportError:
APEX_AVAILABLE = False
else:
APEX_AVAILABLE = True

try:
import torch_xla.core.xla_model as xm

XLA_AVAILABLE = True
except ImportError:
XLA_AVAILABLE = False

try:
import torch_xla.distributed.parallel_loader as xla_pl

XLA_AVAILABLE = True

except ImportError:
XLA_AVAILABLE = False
else:
XLA_AVAILABLE = True


class TrainerTrainLoopMixin(ABC):
@@ -304,93 +297,102 @@ def process_output(self, output, train):
# this is just empty shell for code from other class
pass

def _training_epoch(self, epoch: int):
"""Run single training epoch with `epoch` index."""
# set seed for distributed sampler (enables shuffling for each epoch)
if (self.use_ddp or self.use_tpu) \
and hasattr(self.get_train_dataloader().sampler, 'set_epoch'):
self.get_train_dataloader().sampler.set_epoch(epoch)

# get model
model = self.get_model()

# update training progress in trainer and model
model.current_epoch = epoch
self.current_epoch = epoch

total_val_batches = 0
is_val_epoch = False
if not self.disable_validation:
# val can be checked multiple times in epoch
is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
val_checks_per_epoch = self.num_training_batches // self.val_check_batch
val_checks_per_epoch = val_checks_per_epoch if is_val_epoch else 0
total_val_batches = self.num_val_batches * val_checks_per_epoch

# total batches includes multiple val checks
self.total_batches = self.num_training_batches + total_val_batches
self.batch_loss_value = 0 # accumulated grads

if self.fast_dev_run:
# limit the number of batches to 2 (1 train and 1 val) in fast_dev_run
num_iterations = 2
elif self.is_iterable_train_dataloader:
# for iterable train loader, the progress bar never ends
num_iterations = None
else:
num_iterations = self.total_batches

# reset progress bar
# .reset() doesn't work on disabled progress bar so we should check
if not self.main_progress_bar.disable:
self.main_progress_bar.reset(num_iterations)
desc = f"Epoch {epoch + 1}" if not self.is_iterable_train_dataloader else ""
self.main_progress_bar.set_description(desc)

# changing gradient according accumulation_scheduler
self.accumulation_scheduler.on_epoch_begin()

# -----------------
# RUN TRAINING EPOCH
# -----------------
self.run_training_epoch()

# update LR schedulers
if self.lr_schedulers is not None:
for lr_scheduler in self.lr_schedulers:
# while lr_scheduler.last_epoch < self.current_epoch:
# lr_scheduler.step()
lr_scheduler.step()

if self.reduce_lr_on_plateau_scheduler is not None:
val_loss = self.callback_metrics.get('val_loss')
if val_loss is None:
avail_metrics = ','.join(list(self.callback_metrics.keys()))
m = "ReduceLROnPlateau conditioned on metric val_loss" \
f" which is not available. Available metrics are: {avail_metrics}"
raise MisconfigurationException(m)
# while self.reduce_lr_on_plateau_scheduler.last_epoch < self.current_epoch:
# self.reduce_lr_on_plateau_scheduler.step(val_loss)
self.reduce_lr_on_plateau_scheduler.step(val_loss)

if self.max_steps and self.max_steps == self.global_step:
self.main_progress_bar.close()
model.on_train_end()
return

# early stopping
met_min_epochs = epoch >= self.min_epochs - 1
met_min_steps = self.global_step >= self.min_steps if self.min_steps else True

if (self.enable_early_stop and not self.disable_validation and is_val_epoch and
((met_min_epochs and met_min_steps) or self.fast_dev_run)):
should_stop = self.early_stop_callback.on_epoch_end()
# stop training
stop = should_stop and met_min_epochs
if stop:
self.main_progress_bar.close()
with self.profiler.profile('on_train_end'):
model.on_train_end()
return

def train(self):
warnings.warn('Displayed epoch numbers in the progress bar start from "1" until v0.6.x,'
' but will start from "0" in v0.8.0.', DeprecationWarning)
model = self.get_model()
# run all epochs
for epoch in range(self.current_epoch, self.max_epochs):
# set seed for distributed sampler (enables shuffling for each epoch)
if (self.use_ddp or self.use_tpu) \
and hasattr(self.get_train_dataloader().sampler, 'set_epoch'):
self.get_train_dataloader().sampler.set_epoch(epoch)

# get model
model = self.get_model()

# update training progress in trainer and model
model.current_epoch = epoch
self.current_epoch = epoch

total_val_batches = 0
is_val_epoch = False
if not self.disable_validation:
# val can be checked multiple times in epoch
is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
val_checks_per_epoch = self.num_training_batches // self.val_check_batch
val_checks_per_epoch = val_checks_per_epoch if is_val_epoch else 0
total_val_batches = self.num_val_batches * val_checks_per_epoch

# total batches includes multiple val checks
self.total_batches = self.num_training_batches + total_val_batches
self.batch_loss_value = 0 # accumulated grads

if self.fast_dev_run:
# limit the number of batches to 2 (1 train and 1 val) in fast_dev_run
num_iterations = 2
elif self.is_iterable_train_dataloader:
# for iterable train loader, the progress bar never ends
num_iterations = None
else:
num_iterations = self.total_batches

# reset progress bar
# .reset() doesn't work on disabled progress bar so we should check
if not self.main_progress_bar.disable:
self.main_progress_bar.reset(num_iterations)
desc = f'Epoch {epoch + 1}' if not self.is_iterable_train_dataloader else ''
self.main_progress_bar.set_description(desc)

# changing gradient according accumulation_scheduler
self.accumulation_scheduler.on_epoch_begin()

# -----------------
# RUN TNG EPOCH
# -----------------
self.run_training_epoch()

# update LR schedulers
if self.lr_schedulers is not None:
for lr_scheduler in self.lr_schedulers:
lr_scheduler.step()
if self.reduce_lr_on_plateau_scheduler is not None:
val_loss = self.callback_metrics.get('val_loss')
if val_loss is None:
avail_metrics = ','.join(list(self.callback_metrics.keys()))
m = f'ReduceLROnPlateau conditioned on metric val_loss ' \
f'which is not available. Available metrics are: {avail_metrics}'
raise MisconfigurationException(m)
self.reduce_lr_on_plateau_scheduler.step(val_loss)

if self.max_steps and self.max_steps == self.global_step:
self.main_progress_bar.close()
model.on_train_end()
return

# early stopping
met_min_epochs = epoch >= self.min_epochs - 1
met_min_steps = self.global_step >= self.min_steps if self.min_steps else True

if (self.enable_early_stop and not self.disable_validation and is_val_epoch and
((met_min_epochs and met_min_steps) or self.fast_dev_run)):
should_stop = self.early_stop_callback.on_epoch_end()
# stop training
stop = should_stop and met_min_epochs
if stop:
self.main_progress_bar.close()
with self.profiler.profile('on_train_end'):
model.on_train_end()
return
self._training_epoch(epoch)

self.main_progress_bar.close()

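For readers skimming the cumulative diff above: the change pulls the per-epoch body that used to be inlined in train() out into a new _training_epoch(epoch) method, so train() reduces to a plain loop over epochs that delegates each epoch to that method. Below is a minimal runnable sketch of the resulting control flow, assuming a stripped-down stand-in class; everything except train, _training_epoch, current_epoch and max_epochs is an illustrative placeholder, not the actual TrainerTrainLoopMixin API.

# Minimal sketch of the control flow after this refactor. This is NOT the real
# TrainerTrainLoopMixin; names other than train/_training_epoch/current_epoch/
# max_epochs are illustrative placeholders.
class TrainLoopSketch:
    def __init__(self, max_epochs: int):
        self.current_epoch = 0
        self.max_epochs = max_epochs

    def _training_epoch(self, epoch: int) -> None:
        """Run a single training epoch with `epoch` index."""
        self.current_epoch = epoch
        # in the actual PR this is where the progress-bar reset,
        # run_training_epoch(), LR-scheduler stepping and early stopping live
        print(f"running epoch {epoch}")

    def train(self) -> None:
        # run all epochs; each epoch body is now delegated to _training_epoch
        for epoch in range(self.current_epoch, self.max_epochs):
            self._training_epoch(epoch)


if __name__ == "__main__":
    TrainLoopSketch(max_epochs=3).train()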