diff --git a/efficientdet/main.py b/efficientdet/main.py index e78a70585..3957b7ba9 100644 --- a/efficientdet/main.py +++ b/efficientdet/main.py @@ -79,7 +79,7 @@ 'evaluation.') flags.DEFINE_integer('iterations_per_loop', 100, 'Number of iterations per TPU training loop') -flags.DEFINE_integer('save_checkpoints_steps', 5000, +flags.DEFINE_integer('save_checkpoints_steps', 100, 'Number of iterations per checkpoint save') flags.DEFINE_string( 'training_file_pattern', None, @@ -220,7 +220,6 @@ def _can_partition(spatial_dim): FLAGS.iterations_per_loop if FLAGS.strategy == 'tpu' else 1, num_cores_per_replica=num_cores_per_replica, input_partition_dims=input_partition_dims, - save_checkpoints_steps=FLAGS.save_checkpoints_steps, per_host_input_for_training=tf.estimator.tpu.InputPipelineConfig .PER_HOST_V2) @@ -232,6 +231,7 @@ def _can_partition(spatial_dim): log_step_count_steps=FLAGS.iterations_per_loop, session_config=config_proto, tpu_config=tpu_config, + save_checkpoints_steps=FLAGS.save_checkpoints_steps, tf_random_seed=FLAGS.tf_random_seed, ) else: