refactored training_batch + tests to verify correctness #2328

Merged
merged 11 commits on Jun 23, 2020
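For orientation, here is a minimal sketch of the control flow this refactor introduces, using simplified stand-ins rather than the real Trainer internals: the forward/backward work stays in a closure, but the closure now returns a single result object instead of a bare (loss, metrics) tuple, and the step/zero_grad bookkeeping moves into its own helper.

from types import SimpleNamespace

# Simplified stand-ins for the refactored flow (not the actual Lightning code):
# the closure returns one result object; run_training_batch only reads from it.
def optimizer_closure_sketch(batch_loss, log_metrics):
    # forward + backward would run here; only the result packaging is shown
    return SimpleNamespace(
        loss=batch_loss,
        training_step_output=SimpleNamespace(log_metrics=log_metrics),
    )

def run_training_batch_sketch(batch_loss):
    result = optimizer_closure_sketch(batch_loss, {"train_loss": batch_loss})
    batch_log_metrics = [result.training_step_output.log_metrics]
    # the real method then delegates grad clipping and optimizer.step()
    # to a separate run_batch_backward_pass helper
    return 0, {}, batch_log_metrics, result.training_step_output

print(run_training_batch_sketch(0.5))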
220 changes: 145 additions & 75 deletions in pytorch_lightning/trainer/training_loop.py
@@ -162,6 +162,8 @@ def training_step(self, batch, batch_idx):
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict
from pytorch_lightning.utilities.memory import recursive_detach
import subprocess

try:
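The two new imports above are the plumbing for that refactor: `AttributeDict` behaves like a dict whose keys can also be read as attributes, and `recursive_detach` (used further down) detaches any tensors nested inside the step output so the autograd graph is not retained. A rough, assumed approximation of the `AttributeDict` behaviour, for illustration only; the real class lives in `pytorch_lightning.utilities.parsing`:

# Rough approximation of AttributeDict for illustration only; the real
# implementation is the one imported from pytorch_lightning.utilities.parsing.
class AttributeDictSketch(dict):
    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError as exc:
            raise AttributeError(key) from exc

    def __setattr__(self, key, value):
        self[key] = value

result = AttributeDictSketch(loss=0.25, hiddens=None)
print(result.loss, result["loss"])  # same value via attribute or key access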
@@ -527,8 +529,14 @@ def run_training_epoch(self):
_processed_outputs = self.process_output(epoch_output)
log_epoch_metrics = _processed_outputs[2]
callback_epoch_metrics = _processed_outputs[3]

# add the metrics to the loggers
self.log_metrics(log_epoch_metrics, {})

# add metrics to callbacks
self.callback_metrics.update(callback_epoch_metrics)

# add metrics to progress_bar
self.add_progress_bar_metrics(_processed_outputs[1])

# when no val loop is present or fast-dev-run still need to call checkpoints
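The bare tuple indices above are easy to misread; judging from the `AttributeDict` built from the same `process_output` call later in this diff, the ordering is assumed to be (loss, progress-bar metrics, log metrics, callback metrics, hiddens). A small illustration of how the epoch block routes them:

# Assumed ordering of the process_output result, inferred from the
# AttributeDict construction in optimizer_closure further down this diff.
_processed_outputs = (
    None,                       # [0] reduced loss
    {"acc": 0.91},              # [1] progress-bar metrics
    {"train_loss_epoch": 0.2},  # [2] metrics for the loggers
    {"train_loss_epoch": 0.2},  # [3] metrics exposed to callbacks
    None,                       # [4] hiddens (truncated BPTT)
)
log_epoch_metrics = _processed_outputs[2]
callback_epoch_metrics = _processed_outputs[3]
progress_bar_metrics = _processed_outputs[1]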
@@ -548,10 +556,10 @@ def run_training_batch(self, batch, batch_idx):
grad_norm_dic = {}

# track all metrics for callbacks
all_callback_metrics = []
batch_callback_metrics = []

# track metrics to log
all_log_metrics = []
batch_log_metrics = []

if batch is None:
return 0, grad_norm_dic, {}, {}
@@ -586,87 +594,42 @@ def run_training_batch(self, batch, batch_idx):
for param in group['params']:
param.requires_grad = True

# wrap the forward step in a closure so second order methods work
def optimizer_closure():
# forward pass
with self.profiler.profile('model_forward'):
if self.use_amp and self.use_native_amp:
with torch.cuda.amp.autocast():
output_dict = self.training_forward(split_batch, batch_idx,
opt_idx, self.hiddens)
else:
output_dict = self.training_forward(split_batch, batch_idx, opt_idx, self.hiddens)

# format and reduce outputs accordingly
processed_output = self.process_output(output_dict, train=True)

closure_loss, progress_bar_metrics, log_metrics, callback_metrics, self.hiddens = processed_output

# accumulate loss
# (if accumulate_grad_batches = 1 no effect)
closure_loss = closure_loss / self.accumulate_grad_batches

# backward pass
model_ref = self.get_model()
with self.profiler.profile('model_backward'):
# scale loss for 16 bit
if self.precision == 16 and not self.on_tpu:
closure_loss = model_ref.amp_scale_loss(closure_loss, optimizer, opt_idx)

# do backward pass
model_ref.backward(self, closure_loss, optimizer, opt_idx)

# track metrics for callbacks
all_callback_metrics.append(callback_metrics)

# track progress bar metrics
self.add_progress_bar_metrics(progress_bar_metrics)
all_log_metrics.append(log_metrics)

if self.use_horovod:
# Synchronize Horovod to ensure gradient manipulations (e.g., loss scaling) are valid
optimizer.synchronize()

# insert after step hook
if self.is_function_implemented('on_after_backward'):
model_ref = self.get_model()
with self.profiler.profile('on_after_backward'):
model_ref.on_after_backward()
# -------------------
# calculate loss
# -------------------
opt_closure_result = self.optimizer_closure(
split_batch,
batch_idx,
opt_idx,
optimizer,
self.hiddens
)

return closure_loss, callback_metrics
# ------------------------------
# POST forward bookkeeping
# ------------------------------
batch_callback_metrics.append(opt_closure_result.training_step_output.callback_metrics)
batch_log_metrics.append(opt_closure_result.training_step_output.log_metrics)
self.add_progress_bar_metrics(opt_closure_result.training_step_output.pbar_on_batch_end)

# calculate loss
loss, batch_output = optimizer_closure()
# track hiddens
self.hiddens = opt_closure_result.hiddens

# check if loss or model weights are nan
if self.terminate_on_nan:
self.detect_nan_tensors(loss)
self.detect_nan_tensors(opt_closure_result.loss)

# track total loss for logging (avoid mem leaks)
self.batch_loss_value.append(loss)
self.batch_loss_value.append(opt_closure_result.loss)

# ------------------------------
# BACKWARD PASS
# ------------------------------
# gradient update with accumulated gradients
if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:

# track gradient norms when requested
if batch_idx % self.row_log_interval == 0:
if float(self.track_grad_norm) > 0:
model = self.get_model()
grad_norm_dic = model.grad_norm(
self.track_grad_norm)

# clip gradients
if self.use_amp and self.use_native_amp:
self.scaler.unscale_(optimizer)
self.clip_gradients()

# calls .step(), .zero_grad()
# override function to modify this behavior
model = self.get_model()
with self.profiler.profile('optimizer_step'):
model.optimizer_step(self.current_epoch, batch_idx,
optimizer, opt_idx,
lambda: optimizer_closure()[0])
# backward
grad_norm_dic = self.run_batch_backward_pass(split_batch, batch_idx, opt_idx, optimizer)

# calculate running loss for display
self.running_loss.append(self.batch_loss_value.mean())
@@ -683,12 +646,119 @@ def optimizer_closure():
self.get_model().on_batch_end()

# collapse all metrics into one dict
all_log_metrics = {k: v for d in all_log_metrics for k, v in d.items()}
batch_log_metrics = {k: v for d in batch_log_metrics for k, v in d.items()}

# track all metrics for callbacks
self.callback_metrics.update({k: v for d in all_callback_metrics for k, v in d.items()})
self.callback_metrics.update({k: v for d in batch_callback_metrics for k, v in d.items()})

return 0, grad_norm_dic, all_log_metrics, batch_output
return 0, grad_norm_dic, batch_log_metrics, opt_closure_result.training_step_output_for_epoch_end
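
With the refactor, the two halves of gradient accumulation live in different methods: `optimizer_closure` divides each loss by `accumulate_grad_batches`, while `run_training_batch` only calls the backward-pass helper on every Nth batch. A worked example with an assumed `accumulate_grad_batches = 4`:

# Worked example of the accumulation split (accumulate_grad_batches assumed = 4):
# every closure call backpropagates loss / 4, the step fires on every 4th batch.
accumulate_grad_batches = 4
for batch_idx in range(8):
    raw_loss = 1.0
    closure_loss = raw_loss / accumulate_grad_batches          # inside optimizer_closure
    step_now = (batch_idx + 1) % accumulate_grad_batches == 0  # inside run_training_batch
    print(batch_idx, closure_loss, "step + zero_grad" if step_now else "accumulate")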

def run_batch_backward_pass(self, split_batch, batch_idx, opt_idx, optimizer):
# ------------------
# GRAD NORMS
# ------------------
# track gradient norms when requested
grad_norm_dic = {}
if batch_idx % self.row_log_interval == 0:
if float(self.track_grad_norm) > 0:
model = self.get_model()
grad_norm_dic = model.grad_norm(
self.track_grad_norm)

# ------------------
# CLIP GRADS
# ------------------
if self.use_amp and self.use_native_amp:
self.scaler.unscale_(optimizer)
self.clip_gradients()

# ------------------
# .STEP + ZERO_GRAD
# ------------------
model = self.get_model()
with self.profiler.profile('optimizer_step'):
lambda_closure = lambda: self.optimizer_closure(
split_batch,
batch_idx,
opt_idx,
optimizer,
self.hiddens
).loss
model.optimizer_step(self.current_epoch, batch_idx,
optimizer, opt_idx,
lambda_closure)

return grad_norm_dic
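
The `lambda_closure` handed to `optimizer_step` is what keeps second-order optimizers working: an optimizer such as LBFGS calls the closure itself, possibly several times per step, to re-evaluate the loss. A stand-alone plain-PyTorch illustration of that pattern (not the trainer code):

import torch

# Plain-PyTorch illustration of the closure pattern that optimizer_step relies on:
# LBFGS re-invokes the closure internally to re-evaluate loss and gradients.
w = torch.nn.Parameter(torch.tensor(2.0))
optimizer = torch.optim.LBFGS([w], max_iter=5)

def closure():
    optimizer.zero_grad()
    loss = (w - 3.0) ** 2
    loss.backward()
    return loss

optimizer.step(closure)  # LBFGS may call closure() multiple times here
print(round(w.item(), 3))  # approaches 3.0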

def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens):
"""
wrap the forward step in a closure so second order methods work
"""
# ---------------------------
# FORWARD
# ---------------------------
with self.profiler.profile('model_forward'):
if self.use_amp and self.use_native_amp:
with torch.cuda.amp.autocast():
training_step_output = self.training_forward(split_batch, batch_idx,
opt_idx, hiddens)
else:
training_step_output = self.training_forward(split_batch, batch_idx, opt_idx,
hiddens)

# ----------------------------
# PROCESS THE RESULT
# ----------------------------
# format and reduce outputs accordingly
training_step_output = self.process_output(training_step_output, train=True)

# TODO: temporary part of structured results PR
training_step_output = AttributeDict(
batch_loss=training_step_output[0],
pbar_on_batch_end=training_step_output[1],
log_metrics=training_step_output[2],
callback_metrics=training_step_output[3],
hiddens=training_step_output[4],
)

# if the user decides to finally reduce things in epoch_end, save raw output without graphs
training_step_output_for_epoch_end = recursive_detach(training_step_output)

# accumulate loss
# (if accumulate_grad_batches = 1 no effect)
closure_loss = training_step_output.batch_loss / self.accumulate_grad_batches

# backward pass
model_ref = self.get_model()
with self.profiler.profile('model_backward'):
# scale loss for 16 bit
if self.precision == 16 and not self.on_tpu:
closure_loss = model_ref.amp_scale_loss(closure_loss, optimizer, opt_idx)

# do backward pass
model_ref.backward(self, closure_loss, optimizer, opt_idx)

# once backward has been applied, release graph
closure_loss = closure_loss.detach()
training_step_output.batch_loss = training_step_output.batch_loss.detach()

if self.use_horovod:
# Synchronize Horovod to ensure gradient manipulations (e.g., loss scaling) are valid
optimizer.synchronize()

# insert after step hook
if self.is_function_implemented('on_after_backward'):
model_ref = self.get_model()
with self.profiler.profile('on_after_backward'):
model_ref.on_after_backward()

result = AttributeDict(
loss=closure_loss,
training_step_output=training_step_output,
training_step_output_for_epoch_end=training_step_output_for_epoch_end,
hiddens=training_step_output.hiddens,
)
return result
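
A small, assumed demonstration of why the closure detaches the loss once `backward` has run: an un-detached loss kept for epoch-end aggregation (or appended to `batch_loss_value`) would hold the whole autograd graph in memory, which is also what the `recursive_detach` call on the stored step output avoids.

import torch

# Demonstration (illustrative only) of the memory motivation behind detaching:
# the detached copy no longer references the autograd graph that produced it.
w = torch.nn.Parameter(torch.tensor(1.0))
loss = (w * 2) ** 2
loss.backward()

stored = loss.detach()  # safe to keep around for logging / epoch-end reduction
print(loss.requires_grad, stored.requires_grad, stored.grad_fn)  # True False None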

def _get_optimizers_iterable(self):
if not self.optimizer_frequencies: