update for tf2.4 #908
Conversation
efficientdet/keras/train.py
Outdated
@@ -74,7 +74,7 @@
 flags.DEFINE_integer('batch_size', 64, 'training batch size')
 flags.DEFINE_integer('eval_samples', 5000, 'The number of samples for '
                      'evaluation.')
-flags.DEFINE_integer('steps_per_execution', 1000,
+flags.DEFINE_integer('steps_per_execution', 1,
Disabled by default, since there are some issues in multi-GPU training with an uninitialized optimizer.
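For context, `steps_per_execution` (a `model.compile` argument added in TF 2.4) batches several train steps into one traced call to cut host-side overhead. A toy, TensorFlow-free sketch of the call pattern (the helper `run_steps` and its grouping logic are illustrative, not the repo's code):

```python
def run_steps(train_step, num_steps, steps_per_execution=1):
    """Toy emulation of Keras' steps_per_execution batching.

    With steps_per_execution=N, Keras invokes one traced function that
    runs up to N optimizer steps per call; here we just count how many
    such invocations an epoch would take.
    """
    executions = 0
    done = 0
    while done < num_steps:
        # In real Keras these inner steps run inside a single tf.function.
        for _ in range(min(steps_per_execution, num_steps - done)):
            train_step()
            done += 1
        executions += 1
    return executions
```

With 1000 steps and `steps_per_execution=1000` the whole epoch is a single invocation; defaulting to 1 gives up that overhead saving in exchange for safer behavior while the multi-GPU optimizer-initialization issue is open.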
efficientdet/dataloader.py
Outdated
 input_processor.set_scale_factors_to_output_size()
 image = input_processor.resize_and_crop_image()
 boxes, classes = input_processor.resize_and_crop_boxes()
Resizing the image first can roughly double the pipeline speed.
Interesting finding!
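A rough way to see why the reordering helps: any per-pixel work that runs after the resize only touches the (usually smaller) output image. A toy cost model, with hypothetical unit costs, just to illustrate the claim:

```python
def pipeline_cost(src_pixels, dst_pixels, per_pixel_ops, resize_first):
    """Relative cost of a decode -> per-pixel ops -> resize pipeline.

    The resize itself reads every source pixel once; the remaining
    per-pixel ops touch either the source or the resized image,
    depending on where the resize sits in the pipeline.
    """
    resize_cost = src_pixels
    op_pixels = dst_pixels if resize_first else src_pixels
    return resize_cost + per_pixel_ops * op_pixels
```

For example, a 1024x1024 source scaled to 512x512 cuts the per-pixel work after the resize by 4x, which is consistent with the "double speed" observation for the pipeline as a whole.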
 # TODO(fsx950223): use SyncBatchNorm after TF bug is fixed (incorrect nccl
 # all_reduce). See https://github.com/tensorflow/tensorflow/issues/41980
-return BatchNormalization
+return SyncBatchNormalization
How about the speed of SyncBatchNormalization for multiple GPUs?
It's about 40% slower than BatchNormalization for multiple GPUs. I believe that's acceptable.
According to https://github.com/tensorflow/tensorflow/blob/9489702e35b16a40a1accf3b8b5ed557efae10c7/tensorflow/python/keras/layers/normalization_v2.py#L151, should I split replica_ctx.all_reduce?
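For reference, the "split" being asked about: the stock SyncBatchNormalization (see the linked normalization_v2.py) reduces the batch statistics in one cross-replica call; splitting means issuing a separate all_reduce per statistic. A TensorFlow-free toy of the two call shapes (the reduce helper and the replica data are made up; both shapes should give identical results):

```python
def all_reduce_mean(values):
    """Stand-in for replica_ctx.all_reduce with a MEAN op: average across replicas."""
    return sum(values) / len(values)

def fused_reduce(replica_stats):
    """One reduce over (mean, mean_of_squares) pairs, as in the stock layer."""
    means, mean_sqs = zip(*replica_stats)
    return all_reduce_mean(means), all_reduce_mean(mean_sqs)

def split_reduce(replica_stats):
    """Two independent reduces, one per statistic."""
    mean = all_reduce_mean([m for m, _ in replica_stats])
    mean_sq = all_reduce_mean([s for _, s in replica_stats])
    return mean, mean_sq
```

The split variant trades one larger collective for two smaller ones, which can sidestep collective-implementation bugs like the NCCL one linked above without changing the math.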
I did not quite understand these comments, but I think 40% slower is fine.
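A minimal sketch of the trade-off being decided here (the chooser function and its inputs are hypothetical; the real repo selects the class inside its model-building code):

```python
def choose_batch_norm(strategy, small_per_replica_batch):
    """Pick a Keras batch-norm variant name for distributed training.

    SyncBatchNormalization (tf.keras.layers.experimental in TF 2.4)
    aggregates mean/variance across replicas via all_reduce, which
    improves statistics when the per-GPU batch is small, at roughly a
    40% speed cost versus plain BatchNormalization per the numbers
    above. It previously hit an incorrect-NCCL-all_reduce bug
    (tensorflow/tensorflow#41980).
    """
    if strategy == 'mirrored' and small_per_replica_batch:
        return 'SyncBatchNormalization'
    return 'BatchNormalization'
```

The PR's choice matches the discussion: cross-replica statistics are worth a ~40% slowdown once the underlying all_reduce bug is worked around.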
@@ -622,13 +622,11 @@ def build_model_with_precision(pp, mm, ii, *args, **kwargs):
   inputs = tf.cast(ii, tf.bfloat16)
   with tf.tpu.bfloat16_scope():
     outputs = mm(inputs, *args, **kwargs)
-  set_precision_policy('float32')
After removing these 2 lines, I could train the estimator model with recompute_grad and mixed precision.
Why set the policy back to float32? Could I remove them? @mingxingtan
Yes, please feel free to remove it; it is not necessary.
update for tf2.4 (google#908)
* update for tf2.4
* fix mixed precision with recompute gradient
* update README
* fix multi gpus training
* update README
* fix LossScaleOptimizer bug
* disable steps_per_execution in default
* split all reduce
cc @mingxingtan
I ran into some issues when training the model with Keras on multiple GPUs on TF 2.4; could you check it?
I wonder whether it's a bug in my environment or in TensorFlow.