update for tf2.4 #908

Merged: 8 commits, Dec 27, 2020
Changes from 2 commits
129 changes: 88 additions & 41 deletions efficientdet/dataloader.py
@@ -48,6 +48,14 @@ def __init__(self, image, output_size):
self._crop_offset_y = tf.constant(0)
self._crop_offset_x = tf.constant(0)

@property
def image(self):
return self._image

@image.setter
def image(self, image):
self._image = image

def normalize_image(self):
"""Normalize the image to zero mean and unit variance."""
# The image normalization is identical to Cloud TPU ResNet.
@@ -61,6 +69,7 @@ def normalize_image(self):
scale = tf.expand_dims(scale, axis=0)
scale = tf.expand_dims(scale, axis=0)
self._image /= scale
return self._image

def set_training_random_scale_factors(self,
scale_min,
@@ -126,6 +135,7 @@ def set_scale_factors_to_output_size(self):

def resize_and_crop_image(self, method=tf.image.ResizeMethod.BILINEAR):
"""Resize input image and crop it to the self._output dimension."""
dtype = self._image.dtype
scaled_image = tf.image.resize(
self._image, [self._scaled_height, self._scaled_width], method=method)
scaled_image = scaled_image[self._crop_offset_y:self._crop_offset_y +
@@ -135,7 +145,8 @@ def resize_and_crop_image(self, method=tf.image.ResizeMethod.BILINEAR):
output_image = tf.image.pad_to_bounding_box(scaled_image, 0, 0,
self._output_size[0],
self._output_size[1])
return output_image
self._image = tf.cast(output_image, dtype)
return self._image


class DetectionInputProcessor(InputProcessor):
@@ -245,6 +256,70 @@ def __init__(self,
self._max_instances_per_image = max_instances_per_image or 100
self._debug = debug

def _common_image_process(self, image, classes, boxes, data, params):
# Training time preprocessing.
if params['skip_crowd_during_training']:
indices = tf.where(tf.logical_not(data['groundtruth_is_crowd']))
classes = tf.gather_nd(classes, indices)
boxes = tf.gather_nd(boxes, indices)

if params.get('grid_mask', None):
from aug import gridmask # pylint: disable=g-import-not-at-top
image, boxes = gridmask.gridmask(image, boxes)

if params.get('autoaugment_policy', None):
from aug import autoaugment # pylint: disable=g-import-not-at-top
if params['autoaugment_policy'] == 'randaug':
image, boxes = autoaugment.distort_image_with_randaugment(
image, boxes, num_layers=1, magnitude=15)
else:
image, boxes = autoaugment.distort_image_with_autoaugment(
image, boxes, params['autoaugment_policy'])
return image, boxes, classes

def _resize_image_first(self, image, classes, boxes, data, params):
input_processor = DetectionInputProcessor(image, params['image_size'],
boxes, classes)
if self._is_training:
if params['input_rand_hflip']:
input_processor.random_horizontal_flip()

input_processor.set_training_random_scale_factors(
params['jitter_min'], params['jitter_max'],
params.get('target_size', None))
else:
input_processor.set_scale_factors_to_output_size()

image = input_processor.resize_and_crop_image()
boxes, classes = input_processor.resize_and_crop_boxes()

if self._is_training:
image, boxes, classes = self._common_image_process(image, classes, boxes, data, params)

input_processor.image = image
image = input_processor.normalize_image()
return image, boxes, classes, input_processor.image_scale_to_original

def _resize_image_last(self, image, classes, boxes, data, params):
if self._is_training:
image, boxes, classes = self._common_image_process(image, classes, boxes, data, params)

input_processor = DetectionInputProcessor(image, params['image_size'],
boxes, classes)
if self._is_training:
if params['input_rand_hflip']:
input_processor.random_horizontal_flip()

input_processor.set_training_random_scale_factors(
params['jitter_min'], params['jitter_max'],
params.get('target_size', None))
else:
input_processor.set_scale_factors_to_output_size()
input_processor.normalize_image()
image = input_processor.resize_and_crop_image()
boxes, classes = input_processor.resize_and_crop_boxes()
return image, boxes, classes, input_processor.image_scale_to_original

@tf.autograph.experimental.do_not_convert
def dataset_parser(self, value, example_decoder, anchor_labeler, params):
"""Parse data to a fixed dimension input image and learning targets.
@@ -293,41 +368,14 @@ def dataset_parser(self, value, example_decoder, anchor_labeler, params):
is_crowds = data['groundtruth_is_crowd']
image_masks = data.get('groundtruth_instance_masks', [])
classes = tf.reshape(tf.cast(classes, dtype=tf.float32), [-1, 1])

if self._is_training:
# Training time preprocessing.
if params['skip_crowd_during_training']:
indices = tf.where(tf.logical_not(data['groundtruth_is_crowd']))
classes = tf.gather_nd(classes, indices)
boxes = tf.gather_nd(boxes, indices)

if params.get('grid_mask', None):
from aug import gridmask # pylint: disable=g-import-not-at-top
image, boxes = gridmask.gridmask(image, boxes)

if params.get('autoaugment_policy', None):
from aug import autoaugment # pylint: disable=g-import-not-at-top
if params['autoaugment_policy'] == 'randaug':
image, boxes = autoaugment.distort_image_with_randaugment(
image, boxes, num_layers=1, magnitude=15)
else:
image, boxes = autoaugment.distort_image_with_autoaugment(
image, boxes, params['autoaugment_policy'])

input_processor = DetectionInputProcessor(image, params['image_size'],
boxes, classes)
input_processor.normalize_image()
if self._is_training:
if params['input_rand_hflip']:
input_processor.random_horizontal_flip()

input_processor.set_training_random_scale_factors(
params['jitter_min'], params['jitter_max'],
params.get('target_size', None))
else:
input_processor.set_scale_factors_to_output_size()
image = input_processor.resize_and_crop_image()
boxes, classes = input_processor.resize_and_crop_boxes()
source_area = tf.shape(image)[0] * tf.shape(image)[1]
target_size = utils.parse_image_size(params['image_size'])
target_area = target_size[0] * target_size[1]
# Choose the branch so that augmentation always runs on the smaller image,
# which speeds up the pipeline.
image, boxes, classes, image_scale = tf.cond(source_area > target_area,
lambda: self._resize_image_first(image, classes, boxes, data, params),
lambda: self._resize_image_last(image, classes, boxes, data, params))

# Assign anchors.
(cls_targets, box_targets,
@@ -338,7 +386,6 @@ def dataset_parser(self, value, example_decoder, anchor_labeler, params):
source_id = tf.strings.to_number(source_id)

# Pad groundtruth data for evaluation.
image_scale = input_processor.image_scale_to_original
boxes *= image_scale
is_crowds = tf.cast(is_crowds, dtype=tf.float32)
boxes = pad_to_fixed_size(boxes, -1, [self._max_instances_per_image, 4])
@@ -349,7 +396,7 @@ def dataset_parser(self, value, example_decoder, anchor_labeler, params):
[self._max_instances_per_image, 1])
if params['mixed_precision']:
dtype = (
tf.keras.mixed_precision.experimental.global_policy().compute_dtype)
tf.keras.mixed_precision.global_policy().compute_dtype)
image = tf.cast(image, dtype=dtype)
box_targets = tf.nest.map_structure(
lambda box_target: tf.cast(box_target, dtype=dtype), box_targets)
@@ -427,7 +474,7 @@ def _prefetch_dataset(filename):
return dataset

dataset = dataset.interleave(
_prefetch_dataset, num_parallel_calls=tf.data.experimental.AUTOTUNE)
_prefetch_dataset, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.with_options(self.dataset_options)
if self._is_training:
dataset = dataset.shuffle(64, seed=seed)
@@ -442,12 +489,12 @@ def _prefetch_dataset(filename):
anchor_labeler, params)
# pylint: enable=g-long-lambda
dataset = dataset.map(
map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
map_fn, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.prefetch(batch_size)
dataset = dataset.batch(batch_size, drop_remainder=params['drop_remainder'])
dataset = dataset.map(
lambda *args: self.process_example(params, batch_size, *args))
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
dataset = dataset.prefetch(tf.data.AUTOTUNE)
if self._use_fake_data:
# Turn this dataset into a semi-fake dataset which always loop at the
# first batch. This reduces variance in performance and is useful in
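The new control flow in dataset_parser can be summarized with a minimal, self-contained sketch; the helper names below are illustrative placeholders, not functions from this PR. The point of the source_area > target_area condition is that the relatively expensive augmentations always see the smaller of the source and target images.

import tensorflow as tf

def augment(image, boxes):
  # Stand-in for the training-time augmentations (gridmask / autoaugment).
  return image, boxes

def resize(image, target_size):
  # Stand-in for resize_and_crop_image().
  return tf.image.resize(image, target_size)

def preprocess(image, boxes, target_size=(512, 512)):
  """Always run the augmentations on the smaller of source/target image."""
  source_area = tf.shape(image)[0] * tf.shape(image)[1]
  target_area = target_size[0] * target_size[1]

  def resize_first():
    # Source is larger than the output size: shrink first, then augment.
    small = resize(image, target_size)
    return augment(small, boxes)

  def resize_last():
    # Source is smaller than the output size: augment first, then enlarge.
    aug_image, aug_boxes = augment(image, boxes)
    return resize(aug_image, target_size), aug_boxes

  return tf.cond(source_area > target_area, resize_first, resize_last)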
2 changes: 1 addition & 1 deletion efficientdet/det_model_fn.py
@@ -341,7 +341,7 @@ def model_fn(inputs):

precision = utils.get_precision(params['strategy'], params['mixed_precision'])
cls_outputs, box_outputs = utils.build_model_with_precision(
precision, model_fn, features, params['is_training_bn'])
precision, model_fn, features)

levels = cls_outputs.keys()
for level in levels:
2 changes: 1 addition & 1 deletion efficientdet/inference.py
@@ -159,7 +159,7 @@ def model_arch(feats, model_name=None, **kwargs):
model_arch = det_model_fn.get_model_arch(model_name)

cls_outputs, box_outputs = utils.build_model_with_precision(
precision, model_arch, inputs, False, model_name, **kwargs)
precision, model_arch, inputs, model_name, **kwargs)

if mixed_precision:
# Post-processing has multiple places with hard-coded float32.
4 changes: 3 additions & 1 deletion efficientdet/keras/efficientdet_keras.py
@@ -28,7 +28,9 @@
from keras import tfmot
from keras import util_keras
# pylint: disable=arguments-differ  # for keras layers.

utils.BatchNormalization = util_keras.get_batch_norm(tf.keras.layers.BatchNormalization)
utils.SyncBatchNormalization = util_keras.get_batch_norm(tf.keras.layers.experimental.SyncBatchNormalization)
utils.TpuBatchNormalization = util_keras.get_batch_norm(tf.keras.layers.experimental.SyncBatchNormalization)

def add_n(nodes):
"""A customized add_n to add up a list of tensors."""
6 changes: 3 additions & 3 deletions efficientdet/keras/infer.py
@@ -53,9 +53,9 @@ def main(_):
config.override(FLAGS.hparams)

# Use 'mixed_float16' if running on GPUs.
policy = tf.keras.mixed_precision.experimental.Policy('float32')
tf.keras.mixed_precision.experimental.set_policy(policy)
tf.config.experimental_run_functions_eagerly(FLAGS.debug)
policy = tf.keras.mixed_precision.Policy('float32')
tf.keras.mixed_precision.set_global_policy(policy)
tf.config.run_functions_eagerly(FLAGS.debug)

# Create and run the model.
model = efficientdet_keras.EfficientDetModel(config=config)
4 changes: 2 additions & 2 deletions efficientdet/keras/inference.py
@@ -179,8 +179,8 @@ def __init__(self,
mixed_precision = self.params.get('mixed_precision', None)
precision = utils.get_precision(
self.params.get('strategy', None), mixed_precision)
policy = tf.keras.mixed_precision.experimental.Policy(precision)
tf.keras.mixed_precision.experimental.set_policy(policy)
policy = tf.keras.mixed_precision.Policy(precision)
tf.keras.mixed_precision.set_global_policy(policy)

@property
def model(self):
6 changes: 3 additions & 3 deletions efficientdet/keras/train.py
@@ -163,7 +163,7 @@ def main(_):
tf.config.experimental.set_memory_growth(gpu, True)

if FLAGS.debug:
tf.config.experimental_run_functions_eagerly(True)
tf.config.run_functions_eagerly(True)
tf.debugging.set_log_device_placement(True)
os.environ['TF_DETERMINISTIC_OPS'] = '1'
tf.random.set_seed(FLAGS.tf_random_seed)
@@ -202,8 +202,8 @@ def main(_):
config.override(params, True)
# set mixed precision policy by keras api.
precision = utils.get_precision(config.strategy, config.mixed_precision)
policy = tf.keras.mixed_precision.experimental.Policy(precision)
tf.keras.mixed_precision.experimental.set_policy(policy)
policy = tf.keras.mixed_precision.Policy(precision)
tf.keras.mixed_precision.set_global_policy(policy)

def get_dataset(is_training, config):
file_pattern = (
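For reference, the repeated mixed-precision edits in keras/infer.py, keras/inference.py and keras/train.py all follow the same TF 2.4 pattern of dropping the experimental namespace. A minimal sketch, assuming TensorFlow 2.4+ is installed:

import tensorflow as tf

# TF 2.4 moved the mixed-precision API out of `experimental`.
policy = tf.keras.mixed_precision.Policy('mixed_float16')  # or 'float32' off-GPU
tf.keras.mixed_precision.set_global_policy(policy)

# dataloader.py casts images to the compute dtype of the active policy.
compute_dtype = tf.keras.mixed_precision.global_policy().compute_dtype
print(compute_dtype)  # 'float16' under mixed_float16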
9 changes: 4 additions & 5 deletions efficientdet/keras/train_lib.py
@@ -310,10 +310,9 @@ def get_optimizer(params):
optimizer, average_decay=moving_average_decay, dynamic_decay=True)
precision = utils.get_precision(params['strategy'], params['mixed_precision'])
if precision == 'mixed_float16' and params['loss_scale']:
optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
optimizer,
loss_scale=tf.mixed_precision.experimental.DynamicLossScale(
params['loss_scale']))
initial_scale=params['loss_scale'])
return optimizer


@@ -777,7 +776,7 @@ def train_step(self, data):
loss_vals['reg_l2_loss'] = reg_l2_loss
total_loss += tf.cast(reg_l2_loss, loss_dtype)
if isinstance(self.optimizer,
tf.keras.mixed_precision.experimental.LossScaleOptimizer):
tf.keras.mixed_precision.LossScaleOptimizer):
scaled_loss = self.optimizer.get_scaled_loss(total_loss)
else:
scaled_loss = total_loss
@@ -787,7 +786,7 @@ def train_step(self, data):
trainable_vars = self._freeze_vars()
scaled_gradients = tape.gradient(scaled_loss, trainable_vars)
if isinstance(self.optimizer,
tf.keras.mixed_precision.experimental.LossScaleOptimizer):
tf.keras.mixed_precision.LossScaleOptimizer):
gradients = self.optimizer.get_unscaled_gradients(scaled_gradients)
else:
gradients = scaled_gradients
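The optimizer change in get_optimizer reflects the TF 2.4 LossScaleOptimizer signature, which takes initial_scale directly instead of a DynamicLossScale object. A minimal sketch (the SGD settings and scale value are illustrative, not taken from the config):

import tensorflow as tf

base = tf.keras.optimizers.SGD(learning_rate=0.08, momentum=0.9)
# Dynamic loss scaling, starting from the configured loss_scale value.
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(base, initial_scale=2 ** 15)

# In a custom train_step the loss and gradients are scaled/unscaled explicitly,
# mirroring the isinstance() checks in train_lib.py:
# scaled_loss = optimizer.get_scaled_loss(loss)
# grads = optimizer.get_unscaled_gradients(tape.gradient(scaled_loss, variables))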
7 changes: 7 additions & 0 deletions efficientdet/keras/util_keras.py
@@ -174,3 +174,10 @@ def fp16_to_fp32_nested(input_nested):
else:
return input_nested
return out_tensor_dict

def get_batch_norm(bn_class):
def _wrapper(*args, **kwargs):
if not kwargs.get('name', None):
kwargs['name'] = 'tpu_batch_normalization'
return bn_class(*args, **kwargs)
return _wrapper
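The get_batch_norm wrapper only injects a default layer name, so the stock Keras layers keep the legacy 'tpu_batch_normalization' variable names expected by existing checkpoints. A hypothetical usage sketch (run from the efficientdet directory so the `keras` package resolves):

import tensorflow as tf
from keras import util_keras  # efficientdet/keras/util_keras.py

BatchNorm = util_keras.get_batch_norm(tf.keras.layers.BatchNormalization)

bn_default = BatchNorm()              # name falls back to 'tpu_batch_normalization'
bn_named = BatchNorm(name='stem_bn')  # explicit names pass through unchanged
print(bn_default.name, bn_named.name)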
2 changes: 1 addition & 1 deletion efficientdet/main.py
@@ -355,7 +355,7 @@ def run_train_and_eval(e):
if p.exitcode != 0:
return p.exitcode
else:
tf.compat.v1.reset_default_graph()
tf.reset_default_graph()
run_train_and_eval(e)

else:
10 changes: 5 additions & 5 deletions efficientdet/requirements.txt
@@ -1,12 +1,12 @@
lxml>=4.6.1
absl-py>=0.7.1
absl-py>=0.10.0
matplotlib>=3.0.3
numpy>=1.16.4
numpy>=1.19.4
Pillow>=6.0.0
PyYAML>=5.1
six>=1.12.0
tensorflow>=2.3.0
tensorflow-addons>=0.11.2
six>=1.15.0
tensorflow>=2.4.0
tensorflow-addons>=0.12
neural-structured-learning>=1.3.1
tensorflow-model-optimization>=0.5
Cython>=0.29.13