From e7cc1eb3a4eddc45e3c8fbc0cb13e947161d6e2d Mon Sep 17 00:00:00 2001 From: fsx950223 Date: Fri, 11 Feb 2022 13:11:13 +0800 Subject: [PATCH] fix random seed --- efficientdet/tf2/train.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/efficientdet/tf2/train.py b/efficientdet/tf2/train.py index 162f0ceed..8ca31ce04 100644 --- a/efficientdet/tf2/train.py +++ b/efficientdet/tf2/train.py @@ -101,7 +101,7 @@ def define_flags(): flags.DEFINE_string('model_name', 'efficientdet-d0', 'Model name.') flags.DEFINE_bool('debug', False, 'Enable debug mode') flags.DEFINE_integer( - 'tf_random_seed', 111111, + 'tf_random_seed', None, 'Fixed random seed for deterministic execution across runs for debugging.' ) flags.DEFINE_bool('profile', False, 'Enable profile mode') @@ -163,10 +163,12 @@ def main(_): for gpu in tf.config.list_physical_devices('GPU'): tf.config.experimental.set_memory_growth(gpu, True) + if FLAGS.tf_random_seed: + tf.random.set_seed(FLAGS.tf_random_seed) + if FLAGS.debug: tf.debugging.set_log_device_placement(True) os.environ['TF_DETERMINISTIC_OPS'] = '1' - tf.random.set_seed(FLAGS.tf_random_seed) logging.set_verbosity(logging.DEBUG) if FLAGS.strategy == 'tpu':