diff --git a/efficientnetv2/utils.py b/efficientnetv2/utils.py index 3727ba1ca..039ca65a1 100644 --- a/efficientnetv2/utils.py +++ b/efficientnetv2/utils.py @@ -411,7 +411,7 @@ def set_precision_policy(policy_name=None): logging.info('use mixed precision policy name %s', policy_name) tf.compat.v1.keras.layers.enable_v2_dtype_behavior() policy = tf.keras.mixed_precision.Policy(policy_name) - tf.keras.mixed_precision.set_policy(policy) + tf.keras.mixed_precision.set_global_policy(policy) def build_model_with_precision(pp, mm, ii, tt, *args, **kwargs):