Commit 8a1c443

fix apis
fsx950223 committed Apr 14, 2022
1 parent 59b7f5b commit 8a1c443
Showing 3 changed files with 11 additions and 11 deletions.
2 changes: 1 addition & 1 deletion efficientdet/tf2/infer_lib_test.py
@@ -131,7 +131,7 @@ def test_infer_lib_mixed_precision(self):
         model_params={'mixed_precision': True})
     images = tf.ones((1, 512, 512, 3))
     boxes, scores, classes, valid_lens = driver.serve(images)
-    policy = tf.keras.mixed_precision.experimental.global_policy()
+    policy = tf.keras.mixed_precision.global_policy()
     if policy.name == 'float32':
       self.assertEqual(tf.reduce_mean(boxes), 163.09)
       self.assertEqual(tf.reduce_mean(scores), 0.01000005)
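For context on the one-line change above (not part of the commit): TensorFlow 2.4 moved the Keras mixed-precision API out of the experimental namespace, so the policy query drops the experimental segment. A minimal sketch of the updated call, assuming TF 2.4+:

import tensorflow as tf

# Query the active global dtype policy with the stable API.
policy = tf.keras.mixed_precision.global_policy()
print(policy.name)            # e.g. 'float32' or 'mixed_float16'
print(policy.compute_dtype)   # dtype layers compute in
print(policy.variable_dtype)  # dtype layer variables are stored in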
5 changes: 5 additions & 0 deletions efficientnetv2/main.py
@@ -168,6 +168,11 @@ def build_model(in_images):
   utils.scalar('train/lr', learning_rate)
   optimizer = utils.build_optimizer(
       learning_rate, optimizer_name=config.train.optimizer)
+
+  if config.runtime.mixed_precision and precision == 'mixed_float16':
+    # Wrap optimizer with loss scaling when precision is mixed_float16.
+    optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)
+
   if FLAGS.use_tpu:
     # When using TPU, wrap the optimizer with CrossShardOptimizer which
     # handles synchronization details between different TPU cores. To the
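For context on the hunk above (not part of the commit): tf.keras.mixed_precision.LossScaleOptimizer applies dynamic loss scaling by default, which keeps small float16 gradients from underflowing. A minimal sketch of how a wrapped optimizer is driven in a custom training step; model, loss_fn, x, and y are placeholder names:

import tensorflow as tf

tf.keras.mixed_precision.set_global_policy('mixed_float16')
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9))

def train_step(model, loss_fn, x, y):
  with tf.GradientTape() as tape:
    loss = loss_fn(y, model(x, training=True))
    # Scale the loss so float16 gradients do not underflow.
    scaled_loss = optimizer.get_scaled_loss(loss)
  scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
  # Unscale before applying; the wrapper also adapts the scale over time.
  grads = optimizer.get_unscaled_gradients(scaled_grads)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))
  return loss

The explicit scale/unscale pair is only needed in custom loops like this one; Keras model.fit handles both steps automatically when it sees a LossScaleOptimizer.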
15 changes: 5 additions & 10 deletions efficientnetv2/utils.py
@@ -396,7 +396,7 @@ def _custom_getter(getter, *args, **kwargs):
     yield varscope


-def set_precision_policy(policy_name=None, loss_scale=False):
+def set_precision_policy(policy_name=None):
   """Set precision policy according to the name.

   Args:
@@ -410,14 +410,8 @@ def set_precision_policy(policy_name=None, loss_scale=False):
   assert policy_name in ('mixed_float16', 'mixed_bfloat16', 'float32')
   logging.info('use mixed precision policy name %s', policy_name)
   tf.compat.v1.keras.layers.enable_v2_dtype_behavior()
-  # mixed_float16 training is not supported for now, so disable loss_scale.
-  # float32 and mixed_bfloat16 do not need loss scale for training.
-  if loss_scale:
-    policy = tf.keras.mixed_precision.experimental.Policy(policy_name)
-  else:
-    policy = tf.keras.mixed_precision.experimental.Policy(
-        policy_name, loss_scale=None)
-  tf.keras.mixed_precision.experimental.set_policy(policy)
+  policy = tf.keras.mixed_precision.Policy(policy_name)
+  tf.keras.mixed_precision.set_policy(policy)


def build_model_with_precision(pp, mm, ii, tt, *args, **kwargs):
@@ -438,14 +432,15 @@ def build_model_with_precision(pp, mm, ii, tt, *args, **kwargs):
   Returns:
     the output of mm model.
   """
+  del tt
   if pp == 'mixed_bfloat16':
     set_precision_policy(pp)
     inputs = tf.cast(ii, tf.bfloat16)
     with tf.compat.v1.tpu.bfloat16_scope():
       outputs = mm(inputs, *args, **kwargs)
     set_precision_policy('float32')
   elif pp == 'mixed_float16':
-    set_precision_policy(pp, loss_scale=tt)
+    set_precision_policy(pp)
     inputs = tf.cast(ii, tf.float16)
     with float16_scope():
       outputs = mm(inputs, *args, **kwargs)
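For context on the utils.py changes (not part of the commit): loss scaling now lives on the optimizer rather than on the dtype policy, which is why set_precision_policy no longer takes a loss_scale flag and build_model_with_precision discards tt. A sketch of the equivalent helper against current releases, where tf.keras.mixed_precision.set_global_policy accepts a policy name directly (the hunk above constructs a Policy object and calls set_policy; set_global_policy is the stable spelling in TF 2.4+):

import tensorflow as tf

def set_precision_policy(policy_name=None):
  # No loss_scale argument: scaling is handled by LossScaleOptimizer,
  # not by the dtype policy.
  if not policy_name:
    return
  assert policy_name in ('mixed_float16', 'mixed_bfloat16', 'float32')
  tf.keras.mixed_precision.set_global_policy(policy_name)

set_precision_policy('mixed_bfloat16')
print(tf.keras.mixed_precision.global_policy().name)  # mixed_bfloat16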
