Gradient checkpointing (#711)
* add option each_epoch_in_separate_process

* typos in description

* comments wording

* h.each_epoch_in_separate_process = True in default

* renamed option to run_epoch_in_child_process to avoid confusion

* flags.run_epoch_in_child_process also set to True in default

* h.run_epoch_in_child_process = True : don't need this config

* replaced lambda function with functools.partial to get rid of pylint warning

* gradient checkpointing

* gradient checkpointing

* gradient checkpointing

* remove .ropeproject

* description enhancement

* description cleanup

* gradient checkpoint libraries

* deleted graph editor and gradient checkpointing libraries from this branch

* log message

* remove BUILD

* added back to master

* logging

* graph_editor and gradient checkpointing libs

* deleted:    graph_editor/BUILD

* readme

* readme

* Copyright of gradient checkpointing

* redo

* redo

* third_party linted

* README

* README

* merge conflict typo

* merge conflict typo

* renaming

* no log level reset

* no log level reset

* logging of steps per epoch is no longer correct in the latest train_and_eval mode

* add a bit of verbosity to avoid frustration during graph rebuild

* readme

* readme

* less user discretion

* replaced third party nvgpu with internal module

* replaced third party nvgpu with internal module

* replaced third party nvgpu with internal module

* comments added

* carve out toposort and include it here

* refactor toposort based on this repo reqs

* checkout third party

* minor typo

* cleanup

* cleanup, comments
NikZak authored Sep 21, 2020
1 parent 6112ea1 commit 6ab70e1
Showing 5 changed files with 122 additions and 9 deletions.
22 changes: 22 additions & 0 deletions efficientdet/README.md
@@ -369,4 +369,26 @@ For more instructions about training on TPUs, please refer to the following tutorials:

* EfficientNet tutorial: https://cloud.google.com/tpu/docs/tutorials/efficientnet

## 11. Reducing Memory Usage when Training EfficientDets on GPU.

EfficientDets use a lot of GPU memory for a few reasons:

* Large input resolution: because resolution is one of the scaling dimensions, our input resolution tends to be higher, which significantly increases the activation memory (although it adds no parameters).
* Large internal activations in the backbone: our backbone uses a relatively large expansion ratio (6), which produces large expanded activations.
* Deep BiFPN: our BiFPN has multiple top-down and bottom-up paths, which leads to a lot of intermediate memory usage during training.

To train this model on a GPU with limited memory, there is an experimental option, `gradient_checkpointing`.

Check these links for a high-level idea of what gradient checkpointing does:
1. https://github.com/cybertronai/gradient-checkpointing
2. https://medium.com/tensorflow/fitting-larger-networks-into-memory-583e3c758ff9

**gradient_checkpointing: True**

If set to True, the strings in `gradient_checkpointing_list` (`["Add"]` by default) are searched for in tensor names, and any tensor whose name matches a string from the list is kept as a checkpoint. When this option is used, the standard `tensorflow.python.ops.gradients` function is replaced (monkey-patched) with a custom memory-saving version.
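As a rough sketch (an illustration of the idea, not the repo's actual implementation), name-based checkpoint selection amounts to collecting every forward-pass tensor whose op name contains one of the configured patterns:

```python
import tensorflow.compat.v1 as tf

def select_checkpoint_tensors(patterns):
  """Collects tensors whose op names contain any of the given patterns."""
  graph = tf.get_default_graph()
  return [op.outputs[0]
          for op in graph.get_operations()
          if op.outputs and any(p in op.name for p in patterns)]
```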

Testing shows that:
* On a d4 network with a batch size of 1 (mixed precision enabled), it uses only 1/3.2 of the memory, at the cost of roughly 32% slower computation.
* It also makes it possible to train a d6 network with a batch size of 2 (mixed precision enabled) on an 11GB (2080Ti) GPU.
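For example, assuming the usual `main.py` flags in this repo (a sketch; dataset flags and paths are omitted and must be filled in), the option can be switched on through an `--hparams` override:

```bash
python main.py --mode=train \
  --model_name=efficientdet-d4 \
  --train_batch_size=1 \
  --hparams="gradient_checkpointing=True" \
  --model_dir=/tmp/efficientdet-d4
```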

NOTE: this is not an official Google product.
40 changes: 36 additions & 4 deletions efficientdet/det_model_fn.py
@@ -18,7 +18,6 @@
from absl import logging
import numpy as np
import tensorflow.compat.v1 as tf

import coco_metric
import efficientdet_arch
import hparams_config
@@ -153,7 +152,7 @@ def focal_loss(y_pred, y_true, alpha, gamma, normalizer, label_smoothing=0.0):
pred_prob = tf.sigmoid(y_pred)
p_t = (y_true * pred_prob) + ((1 - y_true) * (1 - pred_prob))
alpha_factor = y_true * alpha + (1 - y_true) * (1 - alpha)
modulating_factor = (1.0 - p_t) ** gamma
modulating_factor = (1.0 - p_t)**gamma

# apply label smoothing for cross_entropy for each entry.
y_true = y_true * (1.0 - label_smoothing) + 0.5 * label_smoothing
@@ -302,8 +301,7 @@ class and box losses from all levels.
box_loss = tf.add_n(box_losses) if box_losses else 0

total_loss = (
cls_loss +
params['box_loss_weight'] * box_loss +
cls_loss + params['box_loss_weight'] * box_loss +
params['iou_loss_weight'] * box_iou_loss)

return total_loss, cls_loss, box_loss, box_iou_loss
@@ -347,6 +345,7 @@ def _model_fn(features, labels, mode, params, model, variable_filter_fn=None):
params['is_training_bn'] = (mode == tf.estimator.ModeKeys.TRAIN)

if params['use_keras_model']:

def model_fn(inputs):
model = efficientdet_keras.EfficientDetNet(
config=hparams_config.Config(params))
@@ -418,6 +417,21 @@ def model_fn(inputs):

if params['strategy'] == 'tpu':
optimizer = tf.tpu.CrossShardOptimizer(optimizer)
if params['gradient_checkpointing']:
from third_party.grad_checkpoint import memory_saving_gradients # pylint: disable=import-outside-toplevel
from tensorflow.python.ops import gradients # pylint: disable=import-outside-toplevel

# monkey patch tf.gradients to point to our custom version,
# with automatic checkpoint selection
def gradients_(ys, xs, grad_ys=None, **kwargs):
return memory_saving_gradients.gradients(
ys,
xs,
grad_ys,
checkpoints=params['gradient_checkpointing_list'],
**kwargs)

gradients.__dict__["gradients"] = gradients_
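      # The TF1 optimizer (tensorflow.python.training.optimizer) resolves
      # gradients.gradients at call time, so from here on the optimizer's
      # compute_gradients should pick up the memory-saving version above.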

# Batch norm requires update_ops to be added as a train_op dependency.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
@@ -615,6 +629,24 @@ def before_run(self, run_context):
every_n_iter=params.get('iterations_per_loop', 100),
)
training_hooks.append(logging_hook)

if params["nvgpu_logging"]:
try:
from third_party.tools.nvgpu import gpu_memory_util_message # pylint: disable=import-outside-toplevel

mem_message = tf.py_func(gpu_memory_util_message, [], [tf.string])[0]

logging_hook_nvgpu = tf.estimator.LoggingTensorHook(
tensors={
"mem_message": mem_message,
},
every_n_iter=params.get('iterations_per_loop', 100),
formatter=lambda x: x["mem_message"].decode("utf-8"),
)
training_hooks.append(logging_hook_nvgpu)
except Exception:  # pylint: disable=broad-except
logging.error("nvgpu error: nvidia-smi format not recognized")

if params['strategy'] == 'tpu':
return tf.estimator.tpu.TPUEstimatorSpec(
mode=mode,
28 changes: 26 additions & 2 deletions efficientdet/hparams_config.py
@@ -136,7 +136,7 @@ def add_kv_recursive(k, v):
return {k: [eval_str_fn(vv) for vv in v.split('*')]}
return {k: eval_str_fn(v)}
pos = k.index('.')
return {k[:pos]: add_kv_recursive(k[pos+1:], v)}
return {k[:pos]: add_kv_recursive(k[pos + 1:], v)}

def merge_dict_recursive(target, src):
"""Recursively merge two nested dictionary."""
@@ -161,7 +161,7 @@ def as_dict(self):
else:
config_dict[k] = copy.deepcopy(v)
return config_dict
# pylint: enable=protected-access
# pylint: enable=protected-access


def default_detection_configs():
@@ -281,6 +281,30 @@ def default_detection_configs():
h.dataset_type = None
h.positives_momentum = None

# Reduces memory during training
h.gradient_checkpointing = False

# Values that can be used are op-type names such as "Add", "Mul", "Conv2d",
# "Floor", or "Sigmoid", or more specific tensor names,
# e.g. "blocks_10/se/conv2d_1".
# E.g. ["Add", "Sigmoid"] automatically checkpoints all "Add" and "Sigmoid"
# ops.
# The advantage of adding more ops is that the GPU does not need to recompute
# them during the backward pass and can use them as a base to recompute other
# nodes, which improves speed.
# The disadvantage is that caching more ops requires more GPU memory.
# The default is ["Add"] because it is a "bottleneck" node in the EfficientNet
# backbone. It has been tested and works reasonably well:
# 1) For a d4 network with a batch size of 1 (mixed precision enabled), it
#    uses only 1/3.2 of the memory, with roughly 32% slower computation.
# 2) It allows training a d6 network with a batch size of 2 and mixed
#    precision on an 11GB (2080Ti) GPU; without this option there is an OOM
#    error.
h.gradient_checkpointing_list = ["Add"]
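# A hypothetical override, for illustration only: checkpointing both Add and
# Sigmoid ops trades extra GPU memory for a faster backward pass, e.g.
# h.gradient_checkpointing_list = ["Add", "Sigmoid"]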

# Enable memory logging for NVIDIA GPUs (parsed from nvidia-smi)
h.nvgpu_logging = False

return h


4 changes: 1 addition & 3 deletions efficientdet/main.py
@@ -26,7 +26,6 @@
import hparams_config
import utils


flags.DEFINE_string(
'tpu',
default=None,
@@ -341,8 +340,7 @@ def run_train_and_eval(e):
input_fn=train_input_fn,
max_steps=e * FLAGS.num_examples_per_epoch // FLAGS.train_batch_size)
print('\n =====> Starting evaluation, epoch: %d.' % e)
eval_results = eval_est.evaluate(
input_fn=eval_input_fn, steps=eval_steps)
eval_results = eval_est.evaluate(input_fn=eval_input_fn, steps=eval_steps)
ckpt = tf.train.latest_checkpoint(FLAGS.model_dir)
utils.archive_ckpt(eval_results, eval_results['AP'], ckpt)

37 changes: 37 additions & 0 deletions efficientdet/third_party/tools/nvgpu.py
@@ -96,3 +96,40 @@ def gpu_info():
return XmlDictConfig(root)
except FileNotFoundError:
return None


def gpu_memory_util_message():
"""Provide information about GPUs."""
gpu_info_d = gpu_info()
if gpu_info_d is not None:
mem_used = gpu_info_d['gpu']['fb_memory_usage']['used']
mem_total = gpu_info_d['gpu']['fb_memory_usage']['total']
mem_util = commonsize(mem_used) / commonsize(mem_total)
logstring = ("GPU memory used: {} = {:.1%} ".format(mem_used, mem_util) +
"of total GPU memory: {}".format(mem_total))
return logstring
return None


def commonsize(input_size):
"""Convert memory information to a common size (MiB)."""
const_sizes = {
'B': 1,
'KB': 1e3,
'MB': 1e6,
'GB': 1e9,
'TB': 1e12,
'PB': 1e15,
'KiB': 1024,
'MiB': 1048576,
'GiB': 1073741824
}
input_size = input_size.split(" ")
# convert all to MiB
if input_size[1] != 'MiB':
converted_size = float(input_size[0]) * (
const_sizes[input_size[1]] / 1048576.0)
else:
converted_size = float(input_size[0])

return converted_size
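# Example inputs, assuming nvidia-smi's "<value> <unit>" format:
# commonsize("11019 MiB") -> 11019.0
# commonsize("2 GiB") -> 2048.0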
