
exclude_from_weight_decay for AdamW and SGDW #2624

Merged: 6 commits merged into tensorflow:master on Jan 3, 2022

Conversation

leondgarse (Contributor)

Description

Brief Description of the PR:

Fixes # (issue)

Type of change

Checklist:

  • I've properly formatted my code according to the guidelines
    • By running Black + Flake8
    • By running pre-commit hooks
  • This PR addresses an already submitted issue for TensorFlow Addons
  • I have made corresponding changes to the documentation
  • I have added tests that prove my fix is effective or that my feature works
  • This PR contains modifications to C++ custom-ops

How Has This Been Tested?

If we uncomment the print statement at line 248, then:

import tensorflow as tf
from tensorflow import keras
import tensorflow_addons as tfa
xx = tf.random.uniform((1000, 32, 32, 3))
yy = tf.one_hot(tf.cast(tf.random.uniform((1000,)) * 10, 'int32'), depth=32)
mm = keras.models.Sequential([
    keras.layers.Input([32, 32, 3]),
    keras.layers.Flatten(),
    keras.layers.BatchNormalization(),
    keras.layers.Dense(32),
])
mm.compile(
    optimizer=tfa.optimizers.AdamW(weight_decay=0.01, exclude_from_weight_decay=['/gamma', '/beta']),
    loss="categorical_crossentropy",
)
mm.fit(xx, yy)
# Filtered: batch_normalization/gamma
# Filtered: batch_normalization/beta
# Filtered: batch_normalization/gamma
# Filtered: batch_normalization/beta
# 32/32 [==============================] - 1s 6ms/step - loss: 8.2161

mm.save('aa.h5')
bb = keras.models.load_model('aa.h5')
print(bb.optimizer.exclude_from_weight_decay)
# ['/gamma', '/beta']

If you're adding a bugfix or new feature, please describe the tests that you ran to verify your changes:

@bot-of-gabrieldemarmiesse

@PhilJd

You are an owner of some files modified in this pull request.
Would you kindly review the changes whenever you have the time?
Thank you very much.

@PhilJd (Contributor) left a comment


def _do_use_weight_decay(self, var):
    """Whether to use L2 weight decay for `var`."""
    if not self._decay_var_list or var.ref() in self._decay_var_list:
        if self.exclude_from_weight_decay:
Contributor

Could we maybe factor this part out into a function in optimizers/utils.py as exclude_from_weight_decay(variable, exclude_regexes: List[str]) -> bool? This would ensure consistency between LAMB and potential future optimizers with exclude lists for weight decay, and it would reduce nesting.

Also, I'd consider it more intuitive to have decay_var_list take priority over exclude_from_weight_decay, since it's a function argument (i.e., if decay_var_list is specified, always decay those variables, independent of a match in exclude_from_weight_decay). WDYT?
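A minimal sketch of the helper being proposed here, assuming the name and signature suggested above (not the final merged API, whose name changes later in this thread):

import re
from typing import List

import tensorflow as tf

def exclude_from_weight_decay(variable: tf.Variable, exclude_regexes: List[str]) -> bool:
    """Sketch: return True if `variable` should be excluded from weight decay,
    i.e. its name matches any of the given regexes."""
    if not exclude_regexes:
        return False
    return any(re.search(regex, variable.name) is not None for regex in exclude_regexes)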

Contributor Author
@leondgarse leondgarse Dec 18, 2021

  • I added a function is_variable_excluded_by_regexes(variable, exclude_regexes: List[str]) -> bool in optimizers/utils.py, since it can also be used in lamb.py's _do_layer_adaptation. I'm happy to rename it if the name doesn't feel right to you. It's mostly called as not is_variable_excluded_by_regexes(...); would you prefer something like is_variable_not_excluded_by_regexes? WDYT?
  • I changed _do_use_weight_decay in optimizers/weight_decay_optimizers.py so that if self._decay_var_list and var.ref() in self._decay_var_list, it returns True, giving decay_var_list higher priority than exclude_from_weight_decay.

Contributor

Thanks for factoring this out and making decay_var_list higher priority!

How about renaming is_variable_excluded_by_regexes to is_variable_matched_by_regexes(variable, regexes)? I feel this is a bit easier to parse. I think having the not in front is fine as it keeps the function name simpler.

Contributor Author

Yes, renamed is_variable_excluded_by_regexes to is_variable_matched_by_regexes, and also renamed its parameter exclude_regexes to just regexes.
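Putting the two review points together, a rough sketch of the renamed helper and the priority rule as discussed in this thread (a sketch of the shape of the change, not a verbatim copy of the merged code):

import re
from typing import List, Optional

import tensorflow as tf

def is_variable_matched_by_regexes(variable: tf.Variable, regexes: Optional[List[str]]) -> bool:
    """True if the variable's name matches any of the given regexes."""
    if regexes:
        for regex in regexes:
            if re.search(regex, variable.name) is not None:
                return True
    return False

# Inside the optimizer, the decision then reads roughly:
#   if self._decay_var_list and var.ref() in self._decay_var_list:
#       return True  # decay_var_list wins over exclude_from_weight_decay
#   return not is_variable_matched_by_regexes(var, self.exclude_from_weight_decay)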

        return True
    return False

def _get_variable_name(self, param_name):
Contributor

Do we need this or wouldn't re.search find the substring anyway?

Contributor Author

I'm not sure about this, as it comes from lamb.py. I kept a get_variable_name function in optimizers/utils.py but am not using it in is_variable_excluded_by_regexes. Maybe the original author can weigh in? @junjiek
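For context, the lamb.py helper under discussion behaves roughly like the following sketch: it strips the ':<index>' suffix that TensorFlow appends to tensor names, which is what prompts the reviewer's question about whether re.search would find the substring anyway.

import re

def get_variable_name(param_name: str) -> str:
    """Sketch: strip the ':<index>' suffix from a tensor name,
    e.g. 'dense/kernel:0' -> 'dense/kernel'."""
    m = re.match("^(.*):\\d+$", param_name)
    if m is not None:
        param_name = m.group(1)
    return param_name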

def __init__(
    self,
    weight_decay: Union[FloatTensorLike, Callable],
    exclude_from_weight_decay: Optional[List[str]] = None,
Contributor

Please add this to the Args section below.

Contributor Author

Added.

Contributor

Thanks! Could you also add a sentence explaining that decay_var_list in minimize takes priority over exclude_from_weight_decay if specified (and add a corresponding sentence to the documentation of minimize)?

Contributor Author

Updated the DecoupledWeightDecayExtension __init__, minimize, and apply_gradients docs, and extended the extend_with_decoupled_weight_decay doc as well. Also added exclude_from_weight_decay to the **kwargs docs of AdamW and SGDW.
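A small usage sketch of the documented behavior, with illustrative variable names: decay_var_list passed to apply_gradients (or minimize) takes priority over exclude_from_weight_decay.

import tensorflow as tf
import tensorflow_addons as tfa

# Illustrative variables; one name matches the exclude regexes.
kernel = tf.Variable(tf.ones([4, 4]), name="dense/kernel")
gamma = tf.Variable(tf.ones([4]), name="batch_normalization/gamma")

opt = tfa.optimizers.AdamW(weight_decay=1e-2, exclude_from_weight_decay=["/gamma", "/beta"])

with tf.GradientTape() as tape:
    loss = tf.reduce_sum(kernel) + tf.reduce_sum(gamma)
grads = tape.gradient(loss, [kernel, gamma])

# Default: gamma matches "/gamma", so it is not decayed.
opt.apply_gradients(zip(grads, [kernel, gamma]))

# With decay_var_list, listed variables are decayed even if they match an exclude regex.
opt.apply_gradients(zip(grads, [kernel, gamma]), decay_var_list=[gamma])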

@leondgarse (Contributor Author)

Four test cases added (a rough sketch of the first is given below):

  • test_exclude_weight_decay_adamw tests the AdamW exclude_from_weight_decay parameter.
  • test_var_list_with_exclude_list_adamw tests that AdamW's decay_var_list takes priority over exclude_from_weight_decay.
  • test_exclude_weight_decay_sgdw and test_var_list_with_exclude_list_sgdw cover the same behavior for SGDW.
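A minimal, self-contained sketch of what such a test could check, with illustrative variable names; this is not the literal test code from the PR.

import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa

def test_exclude_weight_decay_adamw_sketch():
    kernel = tf.Variable([[1.0, 2.0]], name="dense/kernel")
    beta = tf.Variable([1.0, 2.0], name="batch_normalization/beta")
    opt = tfa.optimizers.AdamW(
        learning_rate=0.0, weight_decay=0.1, exclude_from_weight_decay=["/beta"]
    )
    # With lr=0 and zero gradients, only the decoupled weight decay moves variables.
    grads = [tf.zeros_like(kernel), tf.zeros_like(beta)]
    opt.apply_gradients(zip(grads, [kernel, beta]))
    # The excluded beta stays unchanged; the kernel is decayed.
    np.testing.assert_allclose(beta.numpy(), [1.0, 2.0], rtol=1e-6)
    assert not np.allclose(kernel.numpy(), [[1.0, 2.0]])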

@PhilJd (Contributor)

PhilJd commented Jan 3, 2022

Thanks for making these changes, looks good from my side! :)
I'm not a repo owner, so we'll need to wait for an LGTM from them.

Member

@seanpmorgan seanpmorgan left a comment


LGTM. Thanks @leondgarse for the PR and @PhilJd for the review!

@seanpmorgan seanpmorgan merged commit 37a368a into tensorflow:master Jan 3, 2022