-
Notifications
You must be signed in to change notification settings - Fork 1.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Gradient checkpointing #711
Changes from 76 commits
012f92d
79ff583
7ec607a
0426f58
433167a
a2811ff
6ec564a
213f5ef
5269096
5fb26d1
2029e42
d2e864a
c690eb9
8336138
d74dc82
ef6584a
8ddff72
bea39c1
fc3c31f
1daf75f
a098a3c
7db4091
14bb3e1
7adff15
9cfe955
61f1bad
72b85f9
3dbee2e
dbb2066
4026376
e61e4df
752717b
6b97d6d
9694ffb
871d8dc
a919982
06888d1
7169be1
0da6486
b99215d
36ce273
294e53d
c0dbafa
ce4e7cd
d266e24
7fc4659
f1cdb2f
3babbf3
5d1dcf6
e5bda6c
08dd162
d514146
d0ad430
08ea86d
c346a8b
c9f4ab2
f007130
285dd5b
cb67af3
363dbe7
9f99e43
aff920d
ad5edd0
1b5ca8f
23d63d9
657c877
f768439
9771a8b
6026eae
697b0aa
e3dcadb
6eecfca
5214c15
cd09613
913b5bf
d7da5b1
ed54123
ffb122e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -136,7 +136,7 @@ def add_kv_recursive(k, v): | |
return {k: [eval_str_fn(vv) for vv in v.split('*')]} | ||
return {k: eval_str_fn(v)} | ||
pos = k.index('.') | ||
return {k[:pos]: add_kv_recursive(k[pos+1:], v)} | ||
return {k[:pos]: add_kv_recursive(k[pos + 1:], v)} | ||
|
||
def merge_dict_recursive(target, src): | ||
"""Recursively merge two nested dictionary.""" | ||
|
@@ -161,7 +161,7 @@ def as_dict(self): | |
else: | ||
config_dict[k] = copy.deepcopy(v) | ||
return config_dict | ||
# pylint: enable=protected-access | ||
# pylint: enable=protected-access | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe you can move "# pylint: enable=protected-access" right after return (with same indent), to avoid too many empty lines. |
||
|
||
def default_detection_configs(): | ||
|
@@ -281,6 +281,16 @@ def default_detection_configs(): | |
h.dataset_type = None | ||
h.positives_momentum = None | ||
|
||
# Reduces memory during training | ||
h.gradient_checkpointing = False | ||
|
||
# Values that could be used "Add", "Mul", "Conv2d", "Floor", "Sigmoid", etc | ||
# or more specific, e.g. "blocks_10/se/conv2d_1" | ||
h.gradient_checkpointing_list = ["Add"] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you add a comment to explain what values can be used other than "Add"? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for adding more details. Could you explain a little bit more: what's the impact of this list? If I use ["Add"], does it mean it would automatically checkpoint all "Add" operation? If so, what's the pros and cons for adding more ops, and why the default is 'Add'? Sorry if these questions annoy you, but I am hoping to make it clear as this is a greatly useful feature. Thanks! |
||
|
||
# enable memory logging for NVIDIA cards | ||
h.nvgpu_logging = False | ||
|
||
return h | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -96,3 +96,39 @@ def gpu_info(): | |
return XmlDictConfig(root) | ||
except FileNotFoundError: | ||
return None | ||
|
||
|
||
def gpu_memory_util_message(): | ||
"""Provide information about GPUs.""" | ||
gpu_info_d = gpu_info() | ||
if gpu_info_d is not None: | ||
mem_used = gpu_info_d['gpu']['fb_memory_usage']['used'] | ||
mem_total = gpu_info_d['gpu']['fb_memory_usage']['total'] | ||
mem_util = commonsize(mem_used) / commonsize(mem_total) | ||
logstring = ("GPU memory used: {} = {:.1%} ".format(mem_used, mem_util) + | ||
"of total GPU memory: {}".format(mem_total)) | ||
return logstring | ||
return None | ||
|
||
|
||
def commonsize(inp): | ||
"""Convert all to MiB.""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. how about a more informative name such as 'input_size'? Similar, you can rename 'inp_' to 'converted_size' or 'output_size' |
||
const_sizes = { | ||
'B': 1, | ||
'KB': 1e3, | ||
'MB': 1e6, | ||
'GB': 1e9, | ||
'TB': 1e12, | ||
'PB': 1e15, | ||
'KiB': 1024, | ||
'MiB': 1048576, | ||
'GiB': 1073741824 | ||
} | ||
inp = inp.split(" ") | ||
# convert all to MiB | ||
if inp[1] != 'MiB': | ||
inp_ = float(inp[0]) * (const_sizes[inp[1]] / 1048576.0) | ||
else: | ||
inp_ = float(inp[0]) | ||
|
||
return inp_ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice document!