Skip to content

Commit

Permalink
A few minor fixes for model call.
Browse files Browse the repository at this point in the history
  • Loading branch information
mingxingtan committed Jun 10, 2021
1 parent 94f47cd commit f2b4480
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 16 deletions.
2 changes: 1 addition & 1 deletion efficientnetv2/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def build_model(in_images):
"""Build model using the model_name given through the command line."""
config.model.num_classes = config.data.num_classes
model = effnetv2_model.EffNetV2Model(config.model.model_name, config.model)
logits = model(in_images, training=is_training)[0]
logits = model(in_images, training=is_training)
return logits

pre_num_params, pre_num_flops = utils.num_params_flops(readable_format=True)
Expand Down
8 changes: 4 additions & 4 deletions efficientnetv2/main_tf2.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def train_step(self, data):
images, labels = features['image'], labels['label']

with tf.GradientTape() as tape:
pred = self(images, training=True)[0]
pred = self(images, training=True)
pred = tf.cast(pred, tf.float32)
loss = self.compiled_loss(
labels,
Expand All @@ -105,7 +105,7 @@ def train_step(self, data):
def test_step(self, data):
features, labels = data
images, labels = features['image'], labels['label']
pred = self(images, training=False)[0]
pred = self(images, training=False)
pred = tf.cast(pred, tf.float32)

self.compiled_loss(
Expand Down Expand Up @@ -174,9 +174,9 @@ def main(_) -> None:
weight_decay=config.train.weight_decay)

if config.train.ft_init_ckpt: # load pretrained ckpt for finetuning.
model(tf.ones([1, 224, 224, 3]))
model(tf.keras.Input([None, None, 3]))
ckpt = config.train.ft_init_ckpt
utils.restore_tf2_ckpt(model, ckpt, exclude_layers=('_head', 'optimizer'))
utils.restore_tf2_ckpt(model, ckpt, exclude_layers=('_fc', 'optimizer'))

steps_per_epoch = num_train_images // config.train.batch_size
total_steps = steps_per_epoch * config.train.epochs
Expand Down
11 changes: 0 additions & 11 deletions efficientnetv2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,6 @@ def _moments(self, inputs, reduction_axes, keep_dims):

def call(self, inputs, training=None):
outputs = super().call(inputs, training)
# A temporary hack for tf1 compatibility with keras batch norm.
for u in self.updates:
tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.UPDATE_OPS, u)
return outputs


Expand All @@ -217,14 +214,6 @@ def __init__(self, **kwargs):
kwargs['name'] = 'tpu_batch_normalization'
super().__init__(**kwargs)

def call(self, inputs, training=None):
outputs = super().call(inputs, training)
if training and not tf.executing_eagerly():
# A temporary hack for tf1 compatibility with keras batch norm.
for u in self.updates:
tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.UPDATE_OPS, u)
return outputs


def normalization(norm_type: str,
axis=-1,
Expand Down

0 comments on commit f2b4480

Please sign in to comment.