Gradient checkpointing #711

Merged: 78 commits, Sep 21, 2020

Commits
012f92d
add option each_epoch_in_separate_process
NikZak Aug 26, 2020
79ff583
typos in description
NikZak Aug 26, 2020
7ec607a
comments wording
NikZak Aug 26, 2020
0426f58
h.each_epoch_in_separate_process = True in default
NikZak Aug 26, 2020
433167a
renamed option to run_epoch_in_child_process to avoid confusion
NikZak Aug 26, 2020
a2811ff
flags.run_epoch_in_child_process also set to True in default
NikZak Aug 26, 2020
6ec564a
h.run_epoch_in_child_process = True : don't need this config
NikZak Aug 27, 2020
213f5ef
replaced lambda function with functools.partial to get rid of pylint…
NikZak Aug 27, 2020
5269096
gradient checkpointing
NikZak Aug 28, 2020
5fb26d1
gradient checkpointing
NikZak Aug 28, 2020
2029e42
gradient checkpointing
NikZak Aug 28, 2020
d2e864a
remove .ropeproject
NikZak Aug 28, 2020
c690eb9
description enhancement
NikZak Aug 28, 2020
8336138
description cleanup
NikZak Aug 31, 2020
d74dc82
gradient checkpoint libraries
NikZak Aug 31, 2020
ef6584a
deleted graph editor and gradient checkpointing libraries from this branch
NikZak Aug 31, 2020
8ddff72
log message
NikZak Aug 31, 2020
bea39c1
remove BUILD
NikZak Aug 31, 2020
fc3c31f
added back to master
NikZak Aug 31, 2020
1daf75f
logging
NikZak Aug 31, 2020
a098a3c
graph_editor and gradient checkpointing libs
NikZak Aug 31, 2020
7db4091
Merge branch 'gradient_checkpoint_libs'
NikZak Aug 31, 2020
14bb3e1
merge gradient checkpoint to master
NikZak Aug 31, 2020
7adff15
deleted: graph_editor/BUILD
NikZak Aug 31, 2020
9cfe955
readme
NikZak Aug 31, 2020
61f1bad
readme
NikZak Aug 31, 2020
72b85f9
Merge branch 'gradient_checkpoint'
NikZak Aug 31, 2020
3dbee2e
Copyright of gradient checkpointing
NikZak Aug 31, 2020
dbb2066
Merge branch 'gradient_checkpoint_libs'
NikZak Aug 31, 2020
4026376
Merge remote-tracking branch 'origin' into gradient_checkpoint_libs
NikZak Aug 31, 2020
e61e4df
License
NikZak Aug 31, 2020
752717b
Merge branch 'gradient_checkpoint_libs'
NikZak Aug 31, 2020
6b97d6d
redo
NikZak Aug 31, 2020
9694ffb
redo
NikZak Aug 31, 2020
871d8dc
merge with current state
NikZak Sep 7, 2020
a919982
third_party linted
NikZak Sep 7, 2020
06888d1
README
NikZak Sep 7, 2020
7169be1
README
NikZak Sep 7, 2020
0da6486
Merge branch 'master' into gradient_checkpoint
NikZak Sep 7, 2020
b99215d
merge conflict typo
NikZak Sep 7, 2020
36ce273
merge conflict typo
NikZak Sep 7, 2020
294e53d
renaming
NikZak Sep 7, 2020
c0dbafa
Merge branch 'master' into gradient_checkpoint
NikZak Sep 7, 2020
ce4e7cd
no log level reset
NikZak Sep 7, 2020
d266e24
no log level reset
NikZak Sep 7, 2020
7fc4659
Merge branch 'master' of https://github.com/google/automl
NikZak Sep 7, 2020
f1cdb2f
Merge branch 'master' into gradient_checkpoint
NikZak Sep 7, 2020
3babbf3
logging of step per epoch is no longer correct in the latest train_an…
NikZak Sep 7, 2020
5d1dcf6
Merge branch 'master' into gradient_checkpoint
NikZak Sep 7, 2020
e5bda6c
tests rectified
NikZak Sep 8, 2020
08dd162
Merge branch 'master' into gradient_checkpoint
NikZak Sep 8, 2020
d514146
add a bit of verbosity to avoid frustration during graph rebuild
NikZak Sep 8, 2020
d0ad430
Merge branch 'master' into gradient_checkpoint
NikZak Sep 8, 2020
08ea86d
readme
NikZak Sep 8, 2020
c346a8b
readme
NikZak Sep 8, 2020
c9f4ab2
less user discretion
NikZak Sep 8, 2020
f007130
less user discretion
NikZak Sep 8, 2020
285dd5b
replaced third party nvgpu with internal module
NikZak Sep 9, 2020
cb67af3
replaced third party nvgpu with internal module
NikZak Sep 9, 2020
363dbe7
Merge branch 'master' into gradient_checkpoint
NikZak Sep 9, 2020
9f99e43
replaced third party nvgpu with internal module
NikZak Sep 9, 2020
aff920d
replaced third party nvgpu with internal module
NikZak Sep 9, 2020
ad5edd0
comments added
NikZak Sep 9, 2020
1b5ca8f
Merge branch 'master' into gradient_checkpoint
NikZak Sep 9, 2020
23d63d9
carve out toposort and include it here
NikZak Sep 11, 2020
657c877
Merge branch 'master' into gradient_checkpoint
NikZak Sep 11, 2020
f768439
refactor toposort based on this repo reqs
NikZak Sep 11, 2020
9771a8b
Merge branch 'master' into gradient_checkpoint
NikZak Sep 11, 2020
6026eae
checkout third party
NikZak Sep 19, 2020
697b0aa
solved merge upstream conflicts
NikZak Sep 19, 2020
e3dcadb
Merge remote-tracking branch 'upstream/master'
NikZak Sep 19, 2020
6eecfca
minor typo
NikZak Sep 19, 2020
5214c15
Merge branch 'master' into gradient_checkpoint
NikZak Sep 19, 2020
cd09613
Merge branch 'master' of https://github.com/google/automl
NikZak Sep 20, 2020
913b5bf
cleanup
NikZak Sep 20, 2020
d7da5b1
Merge branch 'master' into gradient_checkpoint
NikZak Sep 20, 2020
ed54123
cleanup, comments
NikZak Sep 21, 2020
ffb122e
Merge branch 'master' into gradient_checkpoint
NikZak Sep 21, 2020
50 changes: 36 additions & 14 deletions efficientdet/README.md
@@ -56,15 +56,15 @@ We have provided a list of EfficientDet checkpoints and results as follows:

| Model | AP<sup>test</sup> | AP<sub>50</sub> | AP<sub>75</sub> |AP<sub>S</sub> | AP<sub>M</sub> | AP<sub>L</sub> | AP<sup>val</sup> | | #params | #FLOPs |
|---------- |------ |------ |------ | -------- | ------| ------| ------ |------ |------ | :------: |
| EfficientDet-D0 ([h5](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d0.h5), [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d0.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d0_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d0_coco_test-dev2017.txt)) | 34.6 | 53.0 | 37.1 | 12.4 | 39.0 | 52.7 | 34.3 | | 3.9M | 2.54B |
| EfficientDet-D1 ([h5](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d1.h5), [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d1.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d1_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d1_coco_test-dev2017.txt)) | 40.5 | 59.1 | 43.7 | 18.3 | 45.0 | 57.5 | 40.2 | | 6.6M | 6.10B |
| EfficientDet-D2 ([h5](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d2.h5), [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d2.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d2_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d2_coco_test-dev2017.txt)) | 43.0 | 62.3 | 46.2 | 22.5 | 47.0 | 58.4 | 42.5 | | 8.1M | 11.0B |
| EfficientDet-D3 ([h5](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d3.h5), [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d3.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d3_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d3_coco_test-dev2017.txt)) | 47.5 | 66.2 | 51.5 | 27.9 | 51.4 | 62.0 | 47.2 | | 12.0M | 24.9B |
| EfficientDet-D4 ([h5](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d4.h5), [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d4.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d4_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d4_coco_test-dev2017.txt)) | 49.7 | 68.4 | 53.9 | 30.7 | 53.2 | 63.2 | 49.3 | | 20.7M | 55.2B |
| EfficientDet-D5 ([h5](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d5.h5), [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d5.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d5_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d5_coco_test-dev2017.txt)) | 51.5 | 70.5 | 56.1 | 33.9 | 54.7 | 64.1 | 51.2 | | 33.7M | 130B |
| EfficientDet-D6 ([h5](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d6.h5), [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d6.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d6_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d6_coco_test-dev2017.txt)) | 52.6 | 71.5 | 57.2 | 34.9 | 56.0 | 65.4 | 52.1 | | 51.9M | 226B |
| EfficientDet-D7 ([h5](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d7.h5), [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d7.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d7_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d7_coco_test-dev2017.txt)) | 53.7 | 72.4 | 58.4 | 35.8 | 57.0 | 66.3 | 53.4 | | 51.9M | 325B |
| EfficientDet-D7x ([h5](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d7x.h5), [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d7x.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d7x_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d7x_coco_test-dev2017.txt)) | 55.1 | 74.3 | 59.9 | 37.2 | 57.9 | 68.0 | 54.4 | | 77.0M | 410B |
| EfficientDet-D0 ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d0.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d0_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d0_coco_test-dev2017.txt)) | 34.6 | 53.0 | 37.1 | 12.4 | 39.0 | 52.7 | 34.3 | | 3.9M | 2.54B |
| EfficientDet-D1 ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d1.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d1_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d1_coco_test-dev2017.txt)) | 40.5 | 59.1 | 43.7 | 18.3 | 45.0 | 57.5 | 40.2 | | 6.6M | 6.10B |
| EfficientDet-D2 ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d2.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d2_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d2_coco_test-dev2017.txt)) | 43.0 | 62.3 | 46.2 | 22.5 | 47.0 | 58.4 | 42.5 | | 8.1M | 11.0B |
| EfficientDet-D3 ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d3.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d3_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d3_coco_test-dev2017.txt)) | 47.5 | 66.2 | 51.5 | 27.9 | 51.4 | 62.0 | 47.2 | | 12.0M | 24.9B |
| EfficientDet-D4 ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d4.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d4_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d4_coco_test-dev2017.txt)) | 49.7 | 68.4 | 53.9 | 30.7 | 53.2 | 63.2 | 49.3 | | 20.7M | 55.2B |
| EfficientDet-D5 ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d5.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d5_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d5_coco_test-dev2017.txt)) | 51.5 | 70.5 | 56.1 | 33.9 | 54.7 | 64.1 | 51.2 | | 33.7M | 130B |
| EfficientDet-D6 ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d6.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d6_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d6_coco_test-dev2017.txt)) | 52.6 | 71.5 | 57.2 | 34.9 | 56.0 | 65.4 | 52.1 | | 51.9M | 226B |
| EfficientDet-D7 ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d7.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d7_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d7_coco_test-dev2017.txt)) | 53.7 | 72.4 | 58.4 | 35.8 | 57.0 | 66.3 | 53.4 | | 51.9M | 325B |
| EfficientDet-D7x ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d7x.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d7x_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d7x_coco_test-dev2017.txt)) | 55.1 | 74.3 | 59.9 | 37.2 | 57.9 | 68.0 | 54.4 | | 77.0M | 410B |

<sup><em>val</em> denotes validation results, <em>test-dev</em> denotes test-dev2017 results. AP<sup>val</sup> is for validation accuracy, all other AP results in the table are for COCO test-dev2017. All accuracy numbers are for single-model single-scale without ensemble or test-time augmentation. EfficientDet-D0 to D6 are trained for 300 epochs and D7/D7x are trained for 600 epochs.</sup>

@@ -73,11 +73,11 @@ In addition, the following table includes a list of models trained with fixed 64

| Model | mAP | Latency |
| ------ | ------ | ------ |
| D2(640) [h5](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d2-640.h5), [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d2-640.tar.gz) | 41.7 | 14.8ms |
| D3(640) [h5](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d3-640.h5), [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d3-640.tar.gz) | 44.0 | 18.7ms |
| D4(640) [h5](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d4-640.h5), [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d4-640.tar.gz) | 45.7 | 21.7ms |
| D5(640 [h5](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d5-640.h5), [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d5-640.tar.gz) | 46.6 | 26.6ms |
| D6(640) [h5](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d6-640.h5), [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d6-640.tar.gz) | 47.9 | 33.8ms |
| D2(640) [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d2-640.tar.gz) | 41.7 | 14.8ms |
Reviewer comment (Member): Are these changes intended?

| D3(640) [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d3-640.tar.gz) | 44.0 | 18.7ms |
| D4(640) [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d4-640.tar.gz) | 45.7 | 21.7ms |
| D5(640) [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d5-640.tar.gz) | 46.6 | 26.6ms |
| D6(640) [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d6-640.tar.gz) | 47.9 | 33.8ms |



@@ -369,4 +369,26 @@ For more instructions about training on TPUs, please refer to the following tutorials:

* EfficientNet tutorial: https://cloud.google.com/tpu/docs/tutorials/efficientnet

## 11. Reducing Memory Usage when Training EfficientDets on GPU.

EfficientDets use a lot of GPU memory for a few reasons:

* Large input resolution: because resolution is one of the scaling dimensions, our resolutions tend to be higher, which significantly increases activation memory (although the parameter count does not change).
* Large internal activations in the backbone: our backbone uses a relatively large expansion ratio (6), producing large expanded activations.
* Deep BiFPN: our BiFPN has multiple top-down and bottom-up paths, which leads to a lot of intermediate memory usage during training.

To train this model on a GPU with limited memory, there is an experimental option, gradient_checkpointing.

Check these links for a high-level idea of what gradient checkpointing does:
1. https://github.com/cybertronai/gradient-checkpointing
2. https://medium.com/tensorflow/fitting-larger-networks-into-memory-583e3c758ff9
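
As a toy cost model only (not the implementation in this PR), the trade-off behind those references can be sketched as follows: store only a few checkpoint activations during the forward pass, and recompute everything else from the nearest checkpoint during the backward pass.

```python
import math


def checkpoint_costs(n_layers, n_checkpoints):
    """Toy cost model for checkpointing a chain of n_layers layers.

    Only n_checkpoints activations are stored; during the backward pass
    each segment is recomputed from its nearest checkpoint, which adds
    roughly one extra forward pass in total.
    """
    segment = math.ceil(n_layers / n_checkpoints)
    peak_memory = n_checkpoints + segment  # stored checkpoints + one live segment
    extra_forward_layers = n_layers        # every layer is recomputed once
    return peak_memory, extra_forward_layers


# With ~sqrt(n) checkpoints, peak memory falls from O(n) to O(sqrt(n))
# at the cost of about one extra forward pass:
print(checkpoint_costs(100, 10))  # -> (20, 100)
```

This is the sqrt(n) strategy described in the first link; the option in this PR instead picks checkpoints by matching tensor names, as explained below.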

**gradient_checkpointing: True**

If set to True, the strings defined by gradient_checkpointing_list (["Add"] by default) are searched for in the tensor names, and any tensor whose name matches a string from the list is kept as a checkpoint. When this option is used, the standard tensorflow.python.ops.gradients method is replaced with a custom method.
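
The selection rule can be sketched in isolation as plain Python (select_checkpoints is a hypothetical helper written for illustration, not a function in this PR; the real matching happens inside the custom gradients code):

```python
def select_checkpoints(tensor_names, patterns=("Add",)):
    """Keep every tensor whose name contains one of the patterns.

    Mirrors the rule described above: each entry of
    gradient_checkpointing_list is matched as a substring against
    tensor names, and matching tensors become recompute checkpoints.
    """
    return [name for name in tensor_names
            if any(p in name for p in patterns)]


# Hypothetical tensor names, for illustration only:
names = [
    "efficientnet-b4/blocks_1/Add:0",
    "fpn_cells/cell_0/op_after_combine/Conv2D:0",
    "efficientnet-b4/blocks_2/Add:0",
]
print(select_checkpoints(names))  # keeps only the two Add tensors
```

Note that matching is by substring, so a pattern like "Add" would also match ops such as BiasAdd; coarser or finer patterns change how many tensors are checkpointed and therefore the memory/recompute balance.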

Testing shows that:
* On a D4 network with a batch size of 1 (mixed precision enabled), memory usage drops to about 1/3.2 of the baseline, at the cost of roughly 32% slower computation
Reviewer comment (Member): Nice document!

* It also makes it possible to train a D6 network with a batch size of 2 (mixed precision enabled) on an 11 GB GPU (2080 Ti)
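
Under the rough assumption that throughput scales linearly with batch size (real training only approximates this), the D4 figures above translate into a back-of-envelope net throughput estimate:

```python
def memory_speed_tradeoff(mem_ratio, slowdown):
    """Back-of-envelope estimate: if activation memory shrinks to
    mem_ratio of the baseline, about 1/mem_ratio times the batch fits,
    while each step takes (1 + slowdown) times longer.
    """
    batch_factor = 1.0 / mem_ratio
    throughput_factor = batch_factor / (1.0 + slowdown)
    return batch_factor, throughput_factor


# Reported D4 figures: memory ~1/3.2 of baseline, ~32% slower steps.
batch, throughput = memory_speed_tradeoff(1 / 3.2, 0.32)
print(round(batch, 1), round(throughput, 2))  # -> 3.2 2.42
```

In other words, if the larger batch is actually useful, the 32% per-step slowdown can still yield a net throughput gain; if the batch size stays fixed, checkpointing is purely a memory-for-time trade.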

NOTE: this is not an official Google product.
88 changes: 84 additions & 4 deletions efficientdet/det_model_fn.py
@@ -18,7 +18,6 @@
from absl import logging
import numpy as np
import tensorflow.compat.v1 as tf

import coco_metric
import efficientdet_arch
import hparams_config
@@ -153,7 +152,7 @@ def focal_loss(y_pred, y_true, alpha, gamma, normalizer, label_smoothing=0.0):
pred_prob = tf.sigmoid(y_pred)
p_t = (y_true * pred_prob) + ((1 - y_true) * (1 - pred_prob))
alpha_factor = y_true * alpha + (1 - y_true) * (1 - alpha)
modulating_factor = (1.0 - p_t) ** gamma
modulating_factor = (1.0 - p_t)**gamma

# apply label smoothing for cross_entropy for each entry.
y_true = y_true * (1.0 - label_smoothing) + 0.5 * label_smoothing
@@ -302,8 +301,7 @@ class and box losses from all levels.
box_loss = tf.add_n(box_losses) if box_losses else 0

total_loss = (
cls_loss +
params['box_loss_weight'] * box_loss +
cls_loss + params['box_loss_weight'] * box_loss +
params['iou_loss_weight'] * box_iou_loss)

return total_loss, cls_loss, box_loss, box_iou_loss
@@ -347,6 +345,7 @@ def _model_fn(features, labels, mode, params, model, variable_filter_fn=None):
params['is_training_bn'] = (mode == tf.estimator.ModeKeys.TRAIN)

if params['use_keras_model']:

def model_fn(inputs):
model = efficientdet_keras.EfficientDetNet(
config=hparams_config.Config(params))
@@ -418,6 +417,23 @@ def model_fn(inputs):

if params['strategy'] == 'tpu':
optimizer = tf.tpu.CrossShardOptimizer(optimizer)
if params['gradient_checkpointing']:
from third_party.grad_checkpoint \
import memory_saving_gradients # pylint: disable=g-import-not-at-top
from tensorflow.python.ops \
import gradients # pylint: disable=g-import-not-at-top
Reviewer comment (Member): These imports can probably fit into a single line (try to avoid "").


# monkey patch tf.gradients to point to our custom version,
# with automatic checkpoint selection
def gradients_(ys, xs, grad_ys=None, **kwargs):
return memory_saving_gradients.gradients(
ys,
xs,
grad_ys,
checkpoints=params['gradient_checkpointing_list'],
**kwargs)

gradients.__dict__["gradients"] = gradients_

# Batch norm requires update_ops to be added as a train_op dependency.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
@@ -615,6 +631,70 @@ def before_run(self, run_context):
every_n_iter=params.get('iterations_per_loop', 100),
)
training_hooks.append(logging_hook)

if params["nvgpu_logging"]:
try:
from third_party.tools import nvgpu # pylint: disable=g-import-not-at-top
from functools import reduce # pylint: disable=g-import-not-at-top
Reviewer comment (Member): Just import functools, and use functools.reduce.


def get_nested_value(d, path):
return reduce(dict.get, path, d)
Reviewer comment (Member): Can we move most of the code to nvgpu, so this file can be clean? Thanks. For example, nvgpu_gpu_info, commonsize and formatter_log can be moved to nvgpu.


def nvgpu_gpu_info(inp):
inp = inp.decode("utf-8")
inp = inp.split(",")
inp = [x.strip() for x in inp]
value = get_nested_value(nvgpu.gpu_info(), inp)
return np.str(value)

def commonsize(inp):
const_sizes = {
'B': 1,
'KB': 1e3,
'MB': 1e6,
'GB': 1e9,
'TB': 1e12,
'PB': 1e15,
'KiB': 1024,
'MiB': 1048576,
'GiB': 1073741824
}
inp = inp.split(" ")
# convert all to MiB
if inp[1] != 'MiB':
inp_ = float(inp[0]) * (const_sizes[inp[1]] / 1048576.0)
else:
inp_ = float(inp[0])

return inp_

def formatter_log(tensors):
"""Format the output."""
mem_used = tensors["memory used"].decode("utf-8")
mem_total = tensors["memory total"].decode("utf-8")
mem_util = commonsize(mem_used) / commonsize(mem_total)
logstring = (
"GPU memory used: {} = {:.1%} ".format(mem_used, mem_util) +
"of total GPU memory: {}".format(mem_total))
return logstring

mem_used = tf.py_func(nvgpu_gpu_info, ['gpu, fb_memory_usage, used'],
[tf.string])[0]
mem_total = tf.py_func(nvgpu_gpu_info, ['gpu, fb_memory_usage, total'],
[tf.string])[0]

logging_hook3 = tf.estimator.LoggingTensorHook(
tensors={
"memory used": mem_used,
"memory total": mem_total,
},
every_n_iter=params.get('iterations_per_loop', 100),
formatter=formatter_log,
)
training_hooks.append(logging_hook3)
except:
logging.error("nvgpu error: nvidia-smi format not recognized")

if params['strategy'] == 'tpu':
return tf.estimator.tpu.TPUEstimatorSpec(
mode=mode,
11 changes: 10 additions & 1 deletion efficientdet/hparams_config.py
@@ -136,7 +136,7 @@ def add_kv_recursive(k, v):
return {k: [eval_str_fn(vv) for vv in v.split('*')]}
return {k: eval_str_fn(v)}
pos = k.index('.')
return {k[:pos]: add_kv_recursive(k[pos+1:], v)}
return {k[:pos]: add_kv_recursive(k[pos + 1:], v)}

def merge_dict_recursive(target, src):
"""Recursively merge two nested dictionary."""
@@ -161,6 +161,8 @@ def as_dict(self):
else:
config_dict[k] = copy.deepcopy(v)
return config_dict


# pylint: enable=protected-access

Reviewer comment (Member): Maybe you can move "# pylint: enable=protected-access" right after the return (with the same indent), to avoid too many empty lines.


@@ -281,6 +283,13 @@ def default_detection_configs():
h.dataset_type = None
h.positives_momentum = None

# Reduces memory during training
h.gradient_checkpointing = False
h.gradient_checkpointing_list = ["Add"]
Reviewer comment (Member): Could you add a comment to explain what values can be used other than "Add"?

Reviewer comment (Member): Thanks for adding more details. Could you explain a little bit more: what's the impact of this list? If I use ["Add"], does it mean it would automatically checkpoint all "Add" operations? If I use ["Add", "Sigmoid"], would it automatically checkpoint all "Add" and "Sigmoid" ops? If so, what are the pros and cons of adding more ops, and why is the default "Add"? Sorry if these questions annoy you, but I am hoping to make this clear, as it is a greatly useful feature. Thanks!


# enable memory logging for NVIDIA cards
h.nvgpu_logging = False

return h

