refactor training loop #2336

Merged · 8 commits · Jun 24, 2020
Changes from all commits
188 changes: 115 additions & 73 deletions pytorch_lightning/trainer/training_loop.py
@@ -415,11 +415,16 @@ def train(self):

self.run_training_teardown()

def run_training_epoch(self):
def prepare_train_loop_dataloader(self, train_dataloader):
# on TPU we have to wrap it under the ParallelLoader
if self.use_tpu:
device = xm.xla_device(self.tpu_id)
train_dataloader = xla_pl.ParallelLoader(train_dataloader, [device])
train_dataloader = train_dataloader.per_device_loader(device)

# get model
model = self.get_model()
return train_dataloader

def run_on_epoch_start_hook(self, model):
# Epoch start events
with self.profiler.profile('on_epoch_start'):
# callbacks
@@ -429,17 +434,19 @@ def run_training_epoch(self):
if self.is_function_implemented('on_epoch_start'):
model.on_epoch_start()

# track local dataloader so TPU can wrap each epoch
train_dataloader = self.train_dataloader
def run_training_epoch(self):

# on TPU we have to wrap it under the ParallelLoader
if self.use_tpu:
device = xm.xla_device(self.tpu_id)
train_dataloader = xla_pl.ParallelLoader(train_dataloader, [device])
train_dataloader = train_dataloader.per_device_loader(device)
# get model
model = self.get_model()

# Epoch start events
self.run_on_epoch_start_hook(model)

# modify dataloader if needed (ddp, etc...)
train_dataloader = self.prepare_train_loop_dataloader(self.train_dataloader)

# bookkeeping
outputs = []
epoch_output = []

# run epoch
for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable(
@@ -450,63 +457,41 @@ def run_training_epoch(self):
break

self.batch_idx = batch_idx

model.global_step = self.global_step

# ---------------
# RUN TRAIN STEP
# ---------------
_outputs = self.run_training_batch(batch, batch_idx)
batch_result, grad_norm_dic, batch_step_metrics, batch_output = _outputs
# ------------------------------------
# TRAINING_STEP + TRAINING_STEP_END
# ------------------------------------
batch_output = self.run_training_batch(batch, batch_idx)

# only track outputs when user implements training_epoch_end
# otherwise we will build up unnecessary memory
if self.is_overridden('training_epoch_end', model=self.get_model()):
outputs.append(batch_output)
epoch_output.append(batch_output.training_step_output_for_epoch_end)

# update LR schedulers
self.update_train_loop_lr_schedulers()

# when returning -1 from train_step, we end epoch early
early_stop_epoch = batch_result == -1

# TODO: consolidate all actions that need to take place only after
# self.accumulate_grad_batches steps (optimizer step, lr update, global step increment)
if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
# update lr
self.update_learning_rates(interval='step')

# ---------------
# RUN VAL STEP
# ---------------
is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0
can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
can_check_val = not self.disable_validation and can_check_epoch
should_check_val = is_val_check_batch or early_stop_epoch
should_check_val = should_check_val or (is_last_batch and self.val_check_batch == float('inf'))
should_check_val = can_check_val and should_check_val

# ---------------
# CHECKPOINTING, EARLY STOPPING
# ---------------
# fast_dev_run always forces val checking after train batch
if self.fast_dev_run or should_check_val:
self.run_evaluation(test_mode=self.testing)
self.call_checkpoint_callback()

# when logs should be saved
should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch
if should_save_log or self.fast_dev_run:
if self.is_global_zero and self.logger is not None:
self.logger.save()

# when metrics should be logged
should_log_metrics = batch_idx % self.row_log_interval == 0 or early_stop_epoch
if should_log_metrics or self.fast_dev_run:
# logs user requested information to logger
self.log_metrics(batch_step_metrics, grad_norm_dic)
early_stop_epoch = batch_output.signal == -1

# -----------------------------------------
# VALIDATE IF NEEDED + CHECKPOINT CALLBACK
# -----------------------------------------
should_check_val = self.check_validation_in_train_loop(batch_idx, early_stop_epoch, is_last_batch)

# -----------------------------------------
# SAVE LOGGERS (ie: Tensorboard, etc...)
# -----------------------------------------
self.save_loggers_in_training_loop(batch_idx, early_stop_epoch)

# -----------------------------------------
# SAVE METRICS TO LOGGERS
# -----------------------------------------
self.save_train_loop_metrics_to_loggers(batch_idx, early_stop_epoch, batch_output)

# progress global step according to grads progress
if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
self.global_step += 1
self.total_batch_idx += 1
self.increment_accumulated_grad_global_step()

# max steps reached, end training
if self.max_steps is not None and self.max_steps == self.global_step:
@@ -518,13 +503,36 @@ def run_training_epoch(self):
if early_stop_epoch or self.fast_dev_run:
break

if self.use_horovod:
hvd.join(hvd.local_rank() if self.on_gpu else -1)
# let ddp devices catch up when using horovod
self.sync_horovod()

# process epoch outputs
self.run_training_epoch_end(epoch_output)

# when no val loop is present or fast-dev-run still need to call checkpoints
if not self.is_overridden('validation_step') and not (self.fast_dev_run or should_check_val):
self.call_checkpoint_callback()

# epoch end hook
self.run_on_epoch_end_hook(model)

def update_train_loop_lr_schedulers(self):
if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
# update lr
self.update_learning_rates(interval='step')

def run_on_epoch_end_hook(self, model):
with self.profiler.profile('on_epoch_end'):
# callbacks
self.on_epoch_end()
# model hooks
if self.is_function_implemented('on_epoch_end'):
model.on_epoch_end()

def run_training_epoch_end(self, epoch_output):
model = self.get_model()
if self.is_overridden('training_epoch_end', model=model):
epoch_output = model.training_epoch_end(outputs)
epoch_output = model.training_epoch_end(epoch_output)
_processed_outputs = self.process_output(epoch_output)
log_epoch_metrics = _processed_outputs[2]
callback_epoch_metrics = _processed_outputs[3]
@@ -538,17 +546,45 @@ def run_training_epoch(self):
# add metrics to progress_bar
self.add_progress_bar_metrics(_processed_outputs[1])

# when no val loop is present or fast-dev-run still need to call checkpoints
if not self.is_overridden('validation_step') and not (self.fast_dev_run or should_check_val):
def sync_horovod(self):
if self.use_horovod:
hvd.join(hvd.local_rank() if self.on_gpu else -1)

def increment_accumulated_grad_global_step(self):
# progress global step according to grads progress
if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
self.global_step += 1
self.total_batch_idx += 1

def save_train_loop_metrics_to_loggers(self, batch_idx, early_stop_epoch, batch_output):
# when metrics should be logged
should_log_metrics = batch_idx % self.row_log_interval == 0 or early_stop_epoch
if should_log_metrics or self.fast_dev_run:
# logs user requested information to logger
self.log_metrics(batch_output.batch_log_metrics, batch_output.grad_norm_dic)

def save_loggers_in_training_loop(self, batch_idx, early_stop_epoch):
# when loggers should save to disk
should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch
if should_save_log or self.fast_dev_run:
if self.is_global_zero and self.logger is not None:
self.logger.save()

def check_validation_in_train_loop(self, batch_idx, early_stop_epoch, is_last_batch):
# decide if we should run validation
is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0
can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
can_check_val = not self.disable_validation and can_check_epoch
should_check_val = is_val_check_batch or early_stop_epoch
should_check_val = should_check_val or (is_last_batch and self.val_check_batch == float('inf'))
should_check_val = can_check_val and should_check_val

# if we need to run validation, then also call the checkpoint callback
if self.fast_dev_run or should_check_val:
self.run_evaluation(test_mode=self.testing)
self.call_checkpoint_callback()

# Epoch end events
with self.profiler.profile('on_epoch_end'):
# callbacks
self.on_epoch_end()
# model hooks
if self.is_function_implemented('on_epoch_end'):
model.on_epoch_end()
return should_check_val

def run_training_batch(self, batch, batch_idx):
# track grad norms
@@ -561,7 +597,7 @@ def run_training_batch(self, batch, batch_idx):
batch_log_metrics = []

if batch is None:
return 0, grad_norm_dic, {}, {}
return AttributeDict(signal=0, grad_norm_dic=grad_norm_dic)

# Batch start events
with self.profiler.profile('on_batch_start'):
@@ -571,7 +607,7 @@ def run_training_batch(self, batch, batch_idx):
if self.is_function_implemented('on_batch_start'):
response = self.get_model().on_batch_start(batch)
if response == -1:
return -1, grad_norm_dic, {}, {}
return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic)

splits = [batch]
if self.truncated_bptt_steps is not None:
@@ -650,7 +686,13 @@ def run_training_batch(self, batch, batch_idx):
# track all metrics for callbacks
self.callback_metrics.update({k: v for d in batch_callback_metrics for k, v in d.items()})

return 0, grad_norm_dic, batch_log_metrics, opt_closure_result.training_step_output_for_epoch_end
result = AttributeDict(
signal=0,
grad_norm_dic=grad_norm_dic,
batch_log_metrics=batch_log_metrics,
training_step_output_for_epoch_end=opt_closure_result.training_step_output_for_epoch_end
)
return result

def run_batch_backward_pass(self, split_batch, batch_idx, opt_idx, optimizer):
# ------------------
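
Side note on the refactored run_training_batch return value above: callers now read named fields from an AttributeDict instead of unpacking a positional 4-tuple. The class below is only an illustrative stand-in for AttributeDict (a dict that also allows attribute access), not Lightning's actual implementation; the field names and sample values mirror the diff and the updated tests.

# Illustrative stand-in: a dict whose keys can also be read and written as
# attributes. Lightning ships its own AttributeDict; this is just a sketch.
class AttributeDict(dict):
    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError as exc:
            raise AttributeError(key) from exc

    def __setattr__(self, key, value):
        self[key] = value


# run_training_batch now returns one object with named fields, so callers
# no longer depend on positional order when new fields are added.
batch_output = AttributeDict(
    signal=0,
    grad_norm_dic={},
    batch_log_metrics={'log_acc1': 12.0, 'log_acc2': 7.0},
    training_step_output_for_epoch_end={'pbar_on_batch_end': {'pbar_acc1': 17.0}},
)

assert batch_output.signal == 0
assert batch_output.batch_log_metrics['log_acc1'] == 12.0
assert batch_output['batch_log_metrics']['log_acc2'] == 7.0  # plain dict access still works
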
36 changes: 16 additions & 20 deletions tests/trainer/test_trainer_steps.py
@@ -23,12 +23,11 @@ def test_trainingstep_dict(tmpdir):
break

out = trainer.run_training_batch(batch, batch_idx)
signal, grad_norm_dic, all_log_metrics, training_step_output_for_epoch_end = out
assert signal == 0
assert all_log_metrics['log_acc1'] == 12.0
assert all_log_metrics['log_acc2'] == 7.0
assert out.signal == 0
assert out.batch_log_metrics['log_acc1'] == 12.0
assert out.batch_log_metrics['log_acc2'] == 7.0

pbar_metrics = training_step_output_for_epoch_end['pbar_on_batch_end']
pbar_metrics = out.training_step_output_for_epoch_end['pbar_on_batch_end']
assert pbar_metrics['pbar_acc1'] == 17.0
assert pbar_metrics['pbar_acc2'] == 19.0

@@ -55,12 +54,11 @@ def training_step_with_step_end(tmpdir):
break

out = trainer.run_training_batch(batch, batch_idx)
signal, grad_norm_dic, all_log_metrics, training_step_output_for_epoch_end = out
assert signal == 0
assert all_log_metrics['log_acc1'] == 12.0
assert all_log_metrics['log_acc2'] == 7.0
assert out.signal == 0
assert out.batch_log_metrics['log_acc1'] == 12.0
assert out.batch_log_metrics['log_acc2'] == 7.0

pbar_metrics = training_step_output_for_epoch_end['pbar_on_batch_end']
pbar_metrics = out.training_step_output_for_epoch_end['pbar_on_batch_end']
assert pbar_metrics['pbar_acc1'] == 17.0
assert pbar_metrics['pbar_acc2'] == 19.0

@@ -92,12 +90,11 @@ def test_full_training_loop_dict(tmpdir):
break

out = trainer.run_training_batch(batch, batch_idx)
signal, grad_norm_dic, all_log_metrics, training_step_output_for_epoch_end = out
assert signal == 0
assert all_log_metrics['log_acc1'] == 12.0
assert all_log_metrics['log_acc2'] == 7.0
assert out.signal == 0
assert out.batch_log_metrics['log_acc1'] == 12.0
assert out.batch_log_metrics['log_acc2'] == 7.0

pbar_metrics = training_step_output_for_epoch_end['pbar_on_batch_end']
pbar_metrics = out.training_step_output_for_epoch_end['pbar_on_batch_end']
assert pbar_metrics['pbar_acc1'] == 17.0
assert pbar_metrics['pbar_acc2'] == 19.0

@@ -129,11 +126,10 @@ def test_train_step_epoch_end(tmpdir):
break

out = trainer.run_training_batch(batch, batch_idx)
signal, grad_norm_dic, all_log_metrics, training_step_output_for_epoch_end = out
assert signal == 0
assert all_log_metrics['log_acc1'] == 12.0
assert all_log_metrics['log_acc2'] == 7.0
assert out.signal == 0
assert out.batch_log_metrics['log_acc1'] == 12.0
assert out.batch_log_metrics['log_acc2'] == 7.0

pbar_metrics = training_step_output_for_epoch_end['pbar_on_batch_end']
pbar_metrics = out.training_step_output_for_epoch_end['pbar_on_batch_end']
assert pbar_metrics['pbar_acc1'] == 17.0
assert pbar_metrics['pbar_acc2'] == 19.0
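
One more note on the extracted increment_accumulated_grad_global_step helper in the training-loop diff above: the global step only advances once every accumulate_grad_batches batches, while total_batch_idx counts every batch. A self-contained sketch of that bookkeeping, using a hypothetical LoopState class in place of the Trainer (the attribute names mirror the fields used in the diff):

# Hypothetical LoopState standing in for the Trainer; only the fields used
# by increment_accumulated_grad_global_step are modeled here.
class LoopState:
    def __init__(self, accumulate_grad_batches):
        self.accumulate_grad_batches = accumulate_grad_batches
        self.batch_idx = 0
        self.global_step = 0
        self.total_batch_idx = 0

    def increment_accumulated_grad_global_step(self):
        # the global step tracks optimizer steps, so it only moves forward
        # after a full accumulation window has been processed
        if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
            self.global_step += 1
        self.total_batch_idx += 1


state = LoopState(accumulate_grad_batches=4)
for batch_idx in range(8):
    state.batch_idx = batch_idx
    state.increment_accumulated_grad_global_step()

assert state.total_batch_idx == 8
assert state.global_step == 2  # one global step per 4 accumulated batches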