From 539ab6554c4b47e319a43da29125714bb4adfe35 Mon Sep 17 00:00:00 2001 From: who who who Date: Sun, 27 Dec 2020 10:50:40 +0800 Subject: [PATCH] update for tf2.4 (#908) * update for tf2.4 * fix mixed precision with recompute gradient * update README * fix multi gpus training * update README * fix LossScaleOptimizer bug * disable steps_per_execution in default * split all reduce --- efficientdet/README.md | 2 +- efficientdet/dataloader.py | 129 ++++++++++++++++------- efficientdet/det_model_fn.py | 2 +- efficientdet/inference.py | 2 +- efficientdet/keras/efficientdet_keras.py | 4 +- efficientdet/keras/infer.py | 6 +- efficientdet/keras/inference.py | 4 +- efficientdet/keras/train.py | 8 +- efficientdet/keras/train_lib.py | 14 +-- efficientdet/keras/util_keras.py | 7 ++ efficientdet/main.py | 39 ++++++- efficientdet/requirements.txt | 10 +- efficientdet/utils.py | 29 ++--- 13 files changed, 170 insertions(+), 86 deletions(-) diff --git a/efficientdet/README.md b/efficientdet/README.md index e81823e6f..6f7c7dd5d 100644 --- a/efficientdet/README.md +++ b/efficientdet/README.md @@ -369,7 +369,7 @@ For more instructions about training on TPUs, please refer to the following tuto * EfficientNet tutorial: https://cloud.google.com/tpu/docs/tutorials/efficientnet -## 11. Reducing Memory Usage when Training EfficientDets on GPU. (The current approach doesn't support mirrored multi GPU or mixed-precision training) +## 11. Reducing Memory Usage when Training EfficientDets on GPU. EfficientDets use a lot of GPU memory for a few reasons: diff --git a/efficientdet/dataloader.py b/efficientdet/dataloader.py index b725f3b00..28d4e5e90 100644 --- a/efficientdet/dataloader.py +++ b/efficientdet/dataloader.py @@ -48,6 +48,14 @@ def __init__(self, image, output_size): self._crop_offset_y = tf.constant(0) self._crop_offset_x = tf.constant(0) + @property + def image(self): + return self._image + + @image.setter + def image(self, image): + self._image = image + def normalize_image(self): """Normalize the image to zero mean and unit variance.""" # The image normalization is identical to Cloud TPU ResNet. @@ -61,6 +69,7 @@ def normalize_image(self): scale = tf.expand_dims(scale, axis=0) scale = tf.expand_dims(scale, axis=0) self._image /= scale + return self._image def set_training_random_scale_factors(self, scale_min, @@ -126,6 +135,7 @@ def set_scale_factors_to_output_size(self): def resize_and_crop_image(self, method=tf.image.ResizeMethod.BILINEAR): """Resize input image and crop it to the self._output dimension.""" + dtype = self._image.dtype scaled_image = tf.image.resize( self._image, [self._scaled_height, self._scaled_width], method=method) scaled_image = scaled_image[self._crop_offset_y:self._crop_offset_y + @@ -135,7 +145,8 @@ def resize_and_crop_image(self, method=tf.image.ResizeMethod.BILINEAR): output_image = tf.image.pad_to_bounding_box(scaled_image, 0, 0, self._output_size[0], self._output_size[1]) - return output_image + self._image = tf.cast(output_image, dtype) + return self._image class DetectionInputProcessor(InputProcessor): @@ -245,6 +256,70 @@ def __init__(self, self._max_instances_per_image = max_instances_per_image or 100 self._debug = debug + def _common_image_process(self, image, classes, boxes, data, params): + # Training time preprocessing. + if params['skip_crowd_during_training']: + indices = tf.where(tf.logical_not(data['groundtruth_is_crowd'])) + classes = tf.gather_nd(classes, indices) + boxes = tf.gather_nd(boxes, indices) + + if params.get('grid_mask', None): + from aug import gridmask # pylint: disable=g-import-not-at-top + image, boxes = gridmask.gridmask(image, boxes) + + if params.get('autoaugment_policy', None): + from aug import autoaugment # pylint: disable=g-import-not-at-top + if params['autoaugment_policy'] == 'randaug': + image, boxes = autoaugment.distort_image_with_randaugment( + image, boxes, num_layers=1, magnitude=15) + else: + image, boxes = autoaugment.distort_image_with_autoaugment( + image, boxes, params['autoaugment_policy']) + return image, boxes, classes + + def _resize_image_first(self, image, classes, boxes, data, params): + input_processor = DetectionInputProcessor(image, params['image_size'], + boxes, classes) + if self._is_training: + if params['input_rand_hflip']: + input_processor.random_horizontal_flip() + + input_processor.set_training_random_scale_factors( + params['jitter_min'], params['jitter_max'], + params.get('target_size', None)) + else: + input_processor.set_scale_factors_to_output_size() + + image = input_processor.resize_and_crop_image() + boxes, classes = input_processor.resize_and_crop_boxes() + + if self._is_training: + image, boxes, classes = self._common_image_process(image, classes, boxes, data, params) + + input_processor.image = image + image = input_processor.normalize_image() + return image, boxes, classes, input_processor.image_scale_to_original + + def _resize_image_last(self, image, classes, boxes, data, params): + if self._is_training: + image, boxes, classes = self._common_image_process(image, classes, boxes, data, params) + + input_processor = DetectionInputProcessor(image, params['image_size'], + boxes, classes) + if self._is_training: + if params['input_rand_hflip']: + input_processor.random_horizontal_flip() + + input_processor.set_training_random_scale_factors( + params['jitter_min'], params['jitter_max'], + params.get('target_size', None)) + else: + input_processor.set_scale_factors_to_output_size() + input_processor.normalize_image() + image = input_processor.resize_and_crop_image() + boxes, classes = input_processor.resize_and_crop_boxes() + return image, boxes, classes, input_processor.image_scale_to_original + @tf.autograph.experimental.do_not_convert def dataset_parser(self, value, example_decoder, anchor_labeler, params): """Parse data to a fixed dimension input image and learning targets. @@ -293,41 +368,14 @@ def dataset_parser(self, value, example_decoder, anchor_labeler, params): is_crowds = data['groundtruth_is_crowd'] image_masks = data.get('groundtruth_instance_masks', []) classes = tf.reshape(tf.cast(classes, dtype=tf.float32), [-1, 1]) - - if self._is_training: - # Training time preprocessing. - if params['skip_crowd_during_training']: - indices = tf.where(tf.logical_not(data['groundtruth_is_crowd'])) - classes = tf.gather_nd(classes, indices) - boxes = tf.gather_nd(boxes, indices) - - if params.get('grid_mask', None): - from aug import gridmask # pylint: disable=g-import-not-at-top - image, boxes = gridmask.gridmask(image, boxes) - - if params.get('autoaugment_policy', None): - from aug import autoaugment # pylint: disable=g-import-not-at-top - if params['autoaugment_policy'] == 'randaug': - image, boxes = autoaugment.distort_image_with_randaugment( - image, boxes, num_layers=1, magnitude=15) - else: - image, boxes = autoaugment.distort_image_with_autoaugment( - image, boxes, params['autoaugment_policy']) - - input_processor = DetectionInputProcessor(image, params['image_size'], - boxes, classes) - input_processor.normalize_image() - if self._is_training: - if params['input_rand_hflip']: - input_processor.random_horizontal_flip() - - input_processor.set_training_random_scale_factors( - params['jitter_min'], params['jitter_max'], - params.get('target_size', None)) - else: - input_processor.set_scale_factors_to_output_size() - image = input_processor.resize_and_crop_image() - boxes, classes = input_processor.resize_and_crop_boxes() + source_area = tf.shape(image)[0] * tf.shape(image)[1] + target_size = utils.parse_image_size(params['image_size']) + target_area = target_size[0] * target_size[1] + # set condition in order to always process small + # first which could speed up pipeline + image, boxes, classes, image_scale = tf.cond(source_area > target_area, + lambda: self._resize_image_first(image, classes, boxes, data, params), + lambda: self._resize_image_last(image, classes, boxes, data, params)) # Assign anchors. (cls_targets, box_targets, @@ -338,7 +386,6 @@ def dataset_parser(self, value, example_decoder, anchor_labeler, params): source_id = tf.strings.to_number(source_id) # Pad groundtruth data for evaluation. - image_scale = input_processor.image_scale_to_original boxes *= image_scale is_crowds = tf.cast(is_crowds, dtype=tf.float32) boxes = pad_to_fixed_size(boxes, -1, [self._max_instances_per_image, 4]) @@ -349,7 +396,7 @@ def dataset_parser(self, value, example_decoder, anchor_labeler, params): [self._max_instances_per_image, 1]) if params['mixed_precision']: dtype = ( - tf.keras.mixed_precision.experimental.global_policy().compute_dtype) + tf.keras.mixed_precision.global_policy().compute_dtype) image = tf.cast(image, dtype=dtype) box_targets = tf.nest.map_structure( lambda box_target: tf.cast(box_target, dtype=dtype), box_targets) @@ -427,7 +474,7 @@ def _prefetch_dataset(filename): return dataset dataset = dataset.interleave( - _prefetch_dataset, num_parallel_calls=tf.data.experimental.AUTOTUNE) + _prefetch_dataset, num_parallel_calls=tf.data.AUTOTUNE) dataset = dataset.with_options(self.dataset_options) if self._is_training: dataset = dataset.shuffle(64, seed=seed) @@ -442,12 +489,12 @@ def _prefetch_dataset(filename): anchor_labeler, params) # pylint: enable=g-long-lambda dataset = dataset.map( - map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE) + map_fn, num_parallel_calls=tf.data.AUTOTUNE) dataset = dataset.prefetch(batch_size) dataset = dataset.batch(batch_size, drop_remainder=params['drop_remainder']) dataset = dataset.map( lambda *args: self.process_example(params, batch_size, *args)) - dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE) + dataset = dataset.prefetch(tf.data.AUTOTUNE) if self._use_fake_data: # Turn this dataset into a semi-fake dataset which always loop at the # first batch. This reduces variance in performance and is useful in diff --git a/efficientdet/det_model_fn.py b/efficientdet/det_model_fn.py index 69c93d0aa..dfeb7e23c 100644 --- a/efficientdet/det_model_fn.py +++ b/efficientdet/det_model_fn.py @@ -341,7 +341,7 @@ def model_fn(inputs): precision = utils.get_precision(params['strategy'], params['mixed_precision']) cls_outputs, box_outputs = utils.build_model_with_precision( - precision, model_fn, features, params['is_training_bn']) + precision, model_fn, features) levels = cls_outputs.keys() for level in levels: diff --git a/efficientdet/inference.py b/efficientdet/inference.py index aeafce165..b589c6924 100644 --- a/efficientdet/inference.py +++ b/efficientdet/inference.py @@ -159,7 +159,7 @@ def model_arch(feats, model_name=None, **kwargs): model_arch = det_model_fn.get_model_arch(model_name) cls_outputs, box_outputs = utils.build_model_with_precision( - precision, model_arch, inputs, False, model_name, **kwargs) + precision, model_arch, inputs, model_name, **kwargs) if mixed_precision: # Post-processing has multiple places with hard-coded float32. diff --git a/efficientdet/keras/efficientdet_keras.py b/efficientdet/keras/efficientdet_keras.py index e0779959e..39b3bff89 100644 --- a/efficientdet/keras/efficientdet_keras.py +++ b/efficientdet/keras/efficientdet_keras.py @@ -28,7 +28,9 @@ from keras import tfmot from keras import util_keras # pylint: disable=arguments-differ # fo keras layers. - +utils.BatchNormalization = util_keras.get_batch_norm(tf.keras.layers.BatchNormalization) +utils.SyncBatchNormalization = util_keras.get_batch_norm(tf.keras.layers.experimental.SyncBatchNormalization) +utils.TpuBatchNormalization = util_keras.get_batch_norm(tf.keras.layers.experimental.SyncBatchNormalization) def add_n(nodes): """A customized add_n to add up a list of tensors.""" diff --git a/efficientdet/keras/infer.py b/efficientdet/keras/infer.py index 15b0a45f2..8b2006daa 100644 --- a/efficientdet/keras/infer.py +++ b/efficientdet/keras/infer.py @@ -53,9 +53,9 @@ def main(_): config.override(FLAGS.hparams) # Use 'mixed_float16' if running on GPUs. - policy = tf.keras.mixed_precision.experimental.Policy('float32') - tf.keras.mixed_precision.experimental.set_policy(policy) - tf.config.experimental_run_functions_eagerly(FLAGS.debug) + policy = tf.keras.mixed_precision.Policy('float32') + tf.keras.mixed_precision.set_global_policy(policy) + tf.config.run_functions_eagerly(FLAGS.debug) # Create and run the model. model = efficientdet_keras.EfficientDetModel(config=config) diff --git a/efficientdet/keras/inference.py b/efficientdet/keras/inference.py index c830c921a..56a528850 100644 --- a/efficientdet/keras/inference.py +++ b/efficientdet/keras/inference.py @@ -179,8 +179,8 @@ def __init__(self, mixed_precision = self.params.get('mixed_precision', None) precision = utils.get_precision( self.params.get('strategy', None), mixed_precision) - policy = tf.keras.mixed_precision.experimental.Policy(precision) - tf.keras.mixed_precision.experimental.set_policy(policy) + policy = tf.keras.mixed_precision.Policy(precision) + tf.keras.mixed_precision.set_global_policy(policy) @property def model(self): diff --git a/efficientdet/keras/train.py b/efficientdet/keras/train.py index 101788b1c..d091f710b 100644 --- a/efficientdet/keras/train.py +++ b/efficientdet/keras/train.py @@ -78,7 +78,7 @@ def define_flags(): flags.DEFINE_integer('batch_size', 64, 'training batch size') flags.DEFINE_integer('eval_samples', 5000, 'The number of samples for ' 'evaluation.') - flags.DEFINE_integer('steps_per_execution', 200, + flags.DEFINE_integer('steps_per_execution', 1, 'Number of steps per training execution.') flags.DEFINE_string( 'train_file_pattern', None, @@ -163,7 +163,7 @@ def main(_): tf.config.experimental.set_memory_growth(gpu, True) if FLAGS.debug: - tf.config.experimental_run_functions_eagerly(True) + tf.config.run_functions_eagerly(True) tf.debugging.set_log_device_placement(True) os.environ['TF_DETERMINISTIC_OPS'] = '1' tf.random.set_seed(FLAGS.tf_random_seed) @@ -202,8 +202,8 @@ def main(_): config.override(params, True) # set mixed precision policy by keras api. precision = utils.get_precision(config.strategy, config.mixed_precision) - policy = tf.keras.mixed_precision.experimental.Policy(precision) - tf.keras.mixed_precision.experimental.set_policy(policy) + policy = tf.keras.mixed_precision.Policy(precision) + tf.keras.mixed_precision.set_global_policy(policy) def get_dataset(is_training, config): file_pattern = ( diff --git a/efficientdet/keras/train_lib.py b/efficientdet/keras/train_lib.py index 6f7d2fb5c..7bbddb8a6 100644 --- a/efficientdet/keras/train_lib.py +++ b/efficientdet/keras/train_lib.py @@ -310,10 +310,9 @@ def get_optimizer(params): optimizer, average_decay=moving_average_decay, dynamic_decay=True) precision = utils.get_precision(params['strategy'], params['mixed_precision']) if precision == 'mixed_float16' and params['loss_scale']: - optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer( + optimizer = tf.keras.mixed_precision.LossScaleOptimizer( optimizer, - loss_scale=tf.mixed_precision.experimental.DynamicLossScale( - params['loss_scale'])) + initial_scale=params['loss_scale']) return optimizer @@ -777,17 +776,18 @@ def train_step(self, data): loss_vals['reg_l2_loss'] = reg_l2_loss total_loss += tf.cast(reg_l2_loss, loss_dtype) if isinstance(self.optimizer, - tf.keras.mixed_precision.experimental.LossScaleOptimizer): + tf.keras.mixed_precision.LossScaleOptimizer): scaled_loss = self.optimizer.get_scaled_loss(total_loss) + optimizer = self.optimizer.inner_optimizer else: scaled_loss = total_loss + optimizer = self.optimizer loss_vals['loss'] = total_loss - loss_vals['learning_rate'] = self.optimizer.learning_rate( - self.optimizer.iterations) + loss_vals['learning_rate'] = optimizer.learning_rate(optimizer.iterations) trainable_vars = self._freeze_vars() scaled_gradients = tape.gradient(scaled_loss, trainable_vars) if isinstance(self.optimizer, - tf.keras.mixed_precision.experimental.LossScaleOptimizer): + tf.keras.mixed_precision.LossScaleOptimizer): gradients = self.optimizer.get_unscaled_gradients(scaled_gradients) else: gradients = scaled_gradients diff --git a/efficientdet/keras/util_keras.py b/efficientdet/keras/util_keras.py index 7c7483a56..d2c16477e 100644 --- a/efficientdet/keras/util_keras.py +++ b/efficientdet/keras/util_keras.py @@ -174,3 +174,10 @@ def fp16_to_fp32_nested(input_nested): else: return input_nested return out_tensor_dict + +def get_batch_norm(bn_class): + def _wrapper(*args, **kwargs): + if not kwargs.get('name', None): + kwargs['name'] = 'tpu_batch_normalization' + return bn_class(*args, **kwargs) + return _wrapper \ No newline at end of file diff --git a/efficientdet/main.py b/efficientdet/main.py index 81899104c..d39eeab3c 100644 --- a/efficientdet/main.py +++ b/efficientdet/main.py @@ -19,6 +19,43 @@ from absl import flags from absl import logging import numpy as np + +from tensorflow.python.ops import custom_gradient # pylint:disable=g-direct-tensorflow-import +from tensorflow.python.framework import ops # pylint:disable=g-direct-tensorflow-import + + +def get_variable_by_name(var_name): + """Given a variable name, retrieves a handle on the tensorflow Variable.""" + + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + + def _filter_fn(item): + try: + return var_name == item.op.name + except AttributeError: + # Collection items without operation are ignored. + return False + + candidate_vars = list(filter(_filter_fn, global_vars)) + + if len(candidate_vars) >= 1: + # Filter out non-trainable variables. + candidate_vars = [v for v in candidate_vars if v.trainable] + else: + raise ValueError("Unsuccessful at finding variable {}.".format(var_name)) + + if len(candidate_vars) == 1: + return candidate_vars[0] + elif len(candidate_vars) > 1: + raise ValueError( + "Unsuccessful at finding trainable variable {}. " + "Number of candidates: {}. " + "Candidates: {}".format(var_name, len(candidate_vars), candidate_vars)) + else: + # The variable is not trainable. + return None + +custom_gradient.get_variable_by_name = get_variable_by_name import tensorflow.compat.v1 as tf import dataloader @@ -355,7 +392,7 @@ def run_train_and_eval(e): if p.exitcode != 0: return p.exitcode else: - tf.compat.v1.reset_default_graph() + tf.reset_default_graph() run_train_and_eval(e) else: diff --git a/efficientdet/requirements.txt b/efficientdet/requirements.txt index c9b6dace4..9759ddce2 100644 --- a/efficientdet/requirements.txt +++ b/efficientdet/requirements.txt @@ -1,12 +1,12 @@ lxml>=4.6.1 -absl-py>=0.7.1 +absl-py>=0.10.0 matplotlib>=3.0.3 -numpy>=1.16.4 +numpy>=1.19.4 Pillow>=6.0.0 PyYAML>=5.1 -six>=1.12.0 -tensorflow>=2.3.0 -tensorflow-addons>=0.11.2 +six>=1.15.0 +tensorflow>=2.4.0 +tensorflow-addons>=0.12 neural-structured-learning>=1.3.1 tensorflow-model-optimization>=0.5 Cython>=0.29.13 diff --git a/efficientdet/utils.py b/efficientdet/utils.py index bfd17e19e..3f17df063 100644 --- a/efficientdet/utils.py +++ b/efficientdet/utils.py @@ -219,9 +219,9 @@ def _moments(self, inputs, reduction_axes, keep_dims): # Compute variance using: Var[X]= E[X^2] - E[X]^2. shard_square_of_mean = tf.math.square(shard_mean) shard_mean_of_square = shard_variance + shard_square_of_mean - group_mean, group_mean_of_square = ( - replica_context.all_reduce(tf.distribute.ReduceOp.MEAN, - [shard_mean, shard_mean_of_square])) + group_mean = replica_context.all_reduce(tf.distribute.ReduceOp.MEAN, shard_mean) + group_mean_of_square = replica_context.all_reduce(tf.distribute.ReduceOp.MEAN, + shard_mean_of_square) group_variance = group_mean_of_square - tf.math.square(group_mean) return (group_mean, group_variance) else: @@ -255,9 +255,7 @@ def batch_norm_class(is_training, strategy=None): if is_training and strategy == 'tpu': return TpuBatchNormalization elif is_training and strategy == 'gpus': - # TODO(fsx950223): use SyncBatchNorm after TF bug is fixed (incorrect nccl - # all_reduce). See https://github.com/tensorflow/tensorflow/issues/41980 - return BatchNormalization + return SyncBatchNormalization else: return BatchNormalization @@ -551,7 +549,7 @@ def get_precision(strategy: str, mixed_precision: bool = False): if strategy == 'tpu': return 'mixed_bfloat16' - if tf.config.experimental.list_physical_devices('GPU'): + if tf.config.list_physical_devices('GPU'): return 'mixed_float16' # TODO(fsx950223): Fix CPU float16 inference @@ -582,7 +580,7 @@ def _custom_getter(getter, *args, **kwargs): yield varscope -def set_precision_policy(policy_name: Text = None, loss_scale: bool = False): +def set_precision_policy(policy_name: Text = None): """Set precision policy according to the name. Args: @@ -598,15 +596,11 @@ def set_precision_policy(policy_name: Text = None, loss_scale: bool = False): tf.compat.v1.keras.layers.enable_v2_dtype_behavior() # mixed_float16 training is not supported for now, so disable loss_scale. # float32 and mixed_bfloat16 do not need loss scale for training. - if loss_scale: - policy = tf2.keras.mixed_precision.experimental.Policy(policy_name) - else: - policy = tf2.keras.mixed_precision.experimental.Policy( - policy_name, loss_scale=None) - tf2.keras.mixed_precision.experimental.set_policy(policy) + policy = tf2.keras.mixed_precision.Policy(policy_name) + tf2.keras.mixed_precision.set_global_policy(policy) -def build_model_with_precision(pp, mm, ii, tt, *args, **kwargs): +def build_model_with_precision(pp, mm, ii, *args, **kwargs): """Build model with its inputs/params for a specified precision context. This is highly specific to this codebase, and not intended to be general API. @@ -617,7 +611,6 @@ def build_model_with_precision(pp, mm, ii, tt, *args, **kwargs): pp: A string, precision policy name, such as "mixed_float16". mm: A function, for rmodel builder. ii: A tensor, for model inputs. - tt: A bool, If true, it is for training; otherwise, it is for eval. *args: A list of model arguments. **kwargs: A dict, extra model parameters. @@ -629,13 +622,11 @@ def build_model_with_precision(pp, mm, ii, tt, *args, **kwargs): inputs = tf.cast(ii, tf.bfloat16) with tf.tpu.bfloat16_scope(): outputs = mm(inputs, *args, **kwargs) - set_precision_policy('float32') elif pp == 'mixed_float16': - set_precision_policy(pp, loss_scale=tt) + set_precision_policy(pp) inputs = tf.cast(ii, tf.float16) with float16_scope(): outputs = mm(inputs, *args, **kwargs) - set_precision_policy('float32') elif not pp or pp == 'float32': outputs = mm(ii, *args, **kwargs) else: