refactored training_batch + tests to verify correctness (#2328)
* refactored training_batch
williamFalcon committed Jun 23, 2020
1 parent 29179db commit 0f07381
Showing 3 changed files with 428 additions and 75 deletions.
220 changes: 145 additions & 75 deletions pytorch_lightning/trainer/training_loop.py
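One piece of context the hunks below rely on: the closure's return value changes from a bare tuple to an AttributeDict (newly imported from pytorch_lightning.utilities.parsing), a plain dict whose keys can also be read and written as attributes, which is exactly how the new code uses it (result.loss, training_step_output.batch_loss = ...). A minimal illustration of that behaviour; the field names mirror the ones in the diff, but this snippet is not part of the commit:

from pytorch_lightning.utilities.parsing import AttributeDict

# AttributeDict is a dict that also exposes its keys as attributes.
result = AttributeDict(
    loss=0.25,
    hiddens=None,
    log_metrics={'train_loss': 0.25},
)

assert result.loss == result['loss']   # attribute access and key access are equivalent
result.loss = 0.5                      # attribute assignment writes through to the dict
assert result['loss'] == 0.5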
@@ -162,6 +162,8 @@ def training_step(self, batch, batch_idx):
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict
from pytorch_lightning.utilities.memory import recursive_detach
import subprocess

try:
@@ -527,8 +529,14 @@ def run_training_epoch(self):
_processed_outputs = self.process_output(epoch_output)
log_epoch_metrics = _processed_outputs[2]
callback_epoch_metrics = _processed_outputs[3]

# add the metrics to the loggers
self.log_metrics(log_epoch_metrics, {})

# add metrics to callbacks
self.callback_metrics.update(callback_epoch_metrics)

# add metrics to progress_bar
self.add_progress_bar_metrics(_processed_outputs[1])

# when no val loop is present or fast-dev-run still need to call checkpoints
@@ -548,10 +556,10 @@ def run_training_batch(self, batch, batch_idx):
grad_norm_dic = {}

# track all metrics for callbacks
all_callback_metrics = []
batch_callback_metrics = []

# track metrics to log
all_log_metrics = []
batch_log_metrics = []

if batch is None:
return 0, grad_norm_dic, {}, {}
@@ -586,87 +594,42 @@ def run_training_batch(self, batch, batch_idx):
for param in group['params']:
param.requires_grad = True

# wrap the forward step in a closure so second order methods work
def optimizer_closure():
# forward pass
with self.profiler.profile('model_forward'):
if self.use_amp and self.use_native_amp:
with torch.cuda.amp.autocast():
output_dict = self.training_forward(split_batch, batch_idx,
opt_idx, self.hiddens)
else:
output_dict = self.training_forward(split_batch, batch_idx, opt_idx, self.hiddens)

# format and reduce outputs accordingly
processed_output = self.process_output(output_dict, train=True)

closure_loss, progress_bar_metrics, log_metrics, callback_metrics, self.hiddens = processed_output

# accumulate loss
# (if accumulate_grad_batches = 1 no effect)
closure_loss = closure_loss / self.accumulate_grad_batches

# backward pass
model_ref = self.get_model()
with self.profiler.profile('model_backward'):
# scale loss for 16 bit
if self.precision == 16 and not self.on_tpu:
closure_loss = model_ref.amp_scale_loss(closure_loss, optimizer, opt_idx)

# do backward pass
model_ref.backward(self, closure_loss, optimizer, opt_idx)

# track metrics for callbacks
all_callback_metrics.append(callback_metrics)

# track progress bar metrics
self.add_progress_bar_metrics(progress_bar_metrics)
all_log_metrics.append(log_metrics)

if self.use_horovod:
# Synchronize Horovod to ensure gradient manipulations (e.g., loss scaling) are valid
optimizer.synchronize()

# insert after step hook
if self.is_function_implemented('on_after_backward'):
model_ref = self.get_model()
with self.profiler.profile('on_after_backward'):
model_ref.on_after_backward()
# -------------------
# calculate loss
# -------------------
opt_closure_result = self.optimizer_closure(
split_batch,
batch_idx,
opt_idx,
optimizer,
self.hiddens
)

return closure_loss, callback_metrics
# ------------------------------
# POST forward bookkeeping
# ------------------------------
batch_callback_metrics.append(opt_closure_result.training_step_output.callback_metrics)
batch_log_metrics.append(opt_closure_result.training_step_output.log_metrics)
self.add_progress_bar_metrics(opt_closure_result.training_step_output.pbar_on_batch_end)

# calculate loss
loss, batch_output = optimizer_closure()
# track hiddens
self.hiddens = opt_closure_result.hiddens

# check if loss or model weights are nan
if self.terminate_on_nan:
self.detect_nan_tensors(loss)
self.detect_nan_tensors(opt_closure_result.loss)

# track total loss for logging (avoid mem leaks)
self.batch_loss_value.append(loss)
self.batch_loss_value.append(opt_closure_result.loss)

# ------------------------------
# BACKWARD PASS
# ------------------------------
# gradient update with accumulated gradients
if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:

# track gradient norms when requested
if batch_idx % self.row_log_interval == 0:
if float(self.track_grad_norm) > 0:
model = self.get_model()
grad_norm_dic = model.grad_norm(
self.track_grad_norm)

# clip gradients
if self.use_amp and self.use_native_amp:
self.scaler.unscale_(optimizer)
self.clip_gradients()

# calls .step(), .zero_grad()
# override function to modify this behavior
model = self.get_model()
with self.profiler.profile('optimizer_step'):
model.optimizer_step(self.current_epoch, batch_idx,
optimizer, opt_idx,
lambda: optimizer_closure()[0])
# backward
grad_norm_dic = self.run_batch_backward_pass(split_batch, batch_idx, opt_idx, optimizer)

# calculate running loss for display
self.running_loss.append(self.batch_loss_value.mean())
@@ -683,12 +646,119 @@ def optimizer_closure():
self.get_model().on_batch_end()

# collapse all metrics into one dict
all_log_metrics = {k: v for d in all_log_metrics for k, v in d.items()}
batch_log_metrics = {k: v for d in batch_log_metrics for k, v in d.items()}

# track all metrics for callbacks
self.callback_metrics.update({k: v for d in all_callback_metrics for k, v in d.items()})
self.callback_metrics.update({k: v for d in batch_callback_metrics for k, v in d.items()})

return 0, grad_norm_dic, all_log_metrics, batch_output
return 0, grad_norm_dic, batch_log_metrics, opt_closure_result.training_step_output_for_epoch_end

def run_batch_backward_pass(self, split_batch, batch_idx, opt_idx, optimizer):
# ------------------
# GRAD NORMS
# ------------------
# track gradient norms when requested
grad_norm_dic = {}
if batch_idx % self.row_log_interval == 0:
if float(self.track_grad_norm) > 0:
model = self.get_model()
grad_norm_dic = model.grad_norm(
self.track_grad_norm)

# ------------------
# CLIP GRADS
# ------------------
if self.use_amp and self.use_native_amp:
self.scaler.unscale_(optimizer)
self.clip_gradients()

# ------------------
# .STEP + ZERO_GRAD
# ------------------
model = self.get_model()
with self.profiler.profile('optimizer_step'):
lambda_closure = lambda: self.optimizer_closure(
split_batch,
batch_idx,
opt_idx,
optimizer,
self.hiddens
).loss
model.optimizer_step(self.current_epoch, batch_idx,
optimizer, opt_idx,
lambda_closure)

return grad_norm_dic

def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens):
"""
wrap the forward step in a closure so second order methods work
"""
# ---------------------------
# FORWARD
# ---------------------------
with self.profiler.profile('model_forward'):
if self.use_amp and self.use_native_amp:
with torch.cuda.amp.autocast():
training_step_output = self.training_forward(split_batch, batch_idx,
opt_idx, hiddens)
else:
training_step_output = self.training_forward(split_batch, batch_idx, opt_idx,
hiddens)

# ----------------------------
# PROCESS THE RESULT
# ----------------------------
# format and reduce outputs accordingly
training_step_output = self.process_output(training_step_output, train=True)

# TODO: temporary part of structured results PR
training_step_output = AttributeDict(
batch_loss=training_step_output[0],
pbar_on_batch_end=training_step_output[1],
log_metrics=training_step_output[2],
callback_metrics=training_step_output[3],
hiddens=training_step_output[4],
)

# if the user decides to finally reduce things in epoch_end, save raw output without graphs
training_step_output_for_epoch_end = recursive_detach(training_step_output)

# accumulate loss
# (if accumulate_grad_batches = 1 no effect)
closure_loss = training_step_output.batch_loss / self.accumulate_grad_batches

# backward pass
model_ref = self.get_model()
with self.profiler.profile('model_backward'):
# scale loss for 16 bit
if self.precision == 16 and not self.on_tpu:
closure_loss = model_ref.amp_scale_loss(closure_loss, optimizer, opt_idx)

# do backward pass
model_ref.backward(self, closure_loss, optimizer, opt_idx)

# once backward has been applied, release graph
closure_loss = closure_loss.detach()
training_step_output.batch_loss = training_step_output.batch_loss.detach()

if self.use_horovod:
# Synchronize Horovod to ensure gradient manipulations (e.g., loss scaling) are valid
optimizer.synchronize()

# insert after step hook
if self.is_function_implemented('on_after_backward'):
model_ref = self.get_model()
with self.profiler.profile('on_after_backward'):
model_ref.on_after_backward()

result = AttributeDict(
loss=closure_loss,
training_step_output=training_step_output,
training_step_output_for_epoch_end=training_step_output_for_epoch_end,
hiddens=training_step_output.hiddens,
)
return result

def _get_optimizers_iterable(self):
if not self.optimizer_frequencies:
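A note on the other new import: recursive_detach from pytorch_lightning.utilities.memory is what lets optimizer_closure keep a copy of the raw training_step output for epoch_end ("save raw output without graphs") without holding the autograd graph alive. Roughly, it detaches every tensor inside a possibly nested structure. The sketch below is a simplified stand-in to show the idea, not the library's implementation:

import torch

def detach_nested(output):
    # Detach tensors inside a nested dict so the autograd graph can be freed.
    if isinstance(output, torch.Tensor):
        return output.detach()
    if isinstance(output, dict):
        return {k: detach_nested(v) for k, v in output.items()}
    return output

x = torch.randn(2, requires_grad=True)
out = {'loss': (x * 2).sum(), 'log': {'train_loss': (x * 2).sum()}}
detached = detach_nested(out)
assert not detached['loss'].requires_grad   # safe to store until epoch end
assert out['loss'].requires_grad            # the original is untouched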

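Finally, the reason the forward and backward pass live inside a closure at all (the docstring on the new optimizer_closure method: "wrap the forward step in a closure so second order methods work") is that second-order optimizers invoke the closure themselves, possibly several times per step. The same pattern in plain PyTorch, independent of anything in this commit:

import torch
import torch.nn.functional as F

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.LBFGS(model.parameters(), lr=0.1)
data, target = torch.randn(8, 4), torch.randn(8, 1)

def closure():
    # LBFGS may re-evaluate this several times within a single .step() call,
    # which is why the forward and backward pass both sit inside the closure.
    optimizer.zero_grad()
    loss = F.mse_loss(model(data), target)
    loss.backward()
    return loss

optimizer.step(closure)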
