Add support for excluding vars for finetuning.
When finetuning from a pretrained checkpoint, the new task may have a
different number of classes, so the class-predict variable shapes would mismatch.

fix #40
Also related to #68
mingxingtan committed Mar 29, 2020
1 parent 0918e40 commit c470de8
Showing 3 changed files with 51 additions and 7 deletions.
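To make the motivation concrete, here is a small illustration (not code from this commit; the anchor count and class counts are made-up numbers): the class-prediction head's kernel shape depends on the number of classes, so a checkpoint trained on one dataset cannot be restored into a head built for another.

num_anchors = 9                                        # assumed anchors per location
pretrained_head_shape = (3, 3, 64, num_anchors * 90)   # checkpoint head, 90 classes
finetune_head_shape = (3, 3, 64, num_anchors * 20)     # new task head, 20 classes

# Restoring requires identical shapes, so loading this variable would fail.
# Excluding it via var_exclude_expr='.*class-predict.*' lets the rest of the
# checkpoint load while the class head is trained from scratch.
assert pretrained_head_shape != finetune_head_shape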
7 changes: 5 additions & 2 deletions efficientdet/det_model_fn.py
@@ -560,8 +560,11 @@ def scaffold_fn():
"""Loads pretrained model through scaffold function."""
logging.info('restore variables from %s', checkpoint)

var_map = utils.get_ckt_var_map(
ckpt_path=checkpoint, ckpt_scope=ckpt_scope, var_scope=var_scope)
var_map = utils.get_ckpt_var_map(
ckpt_path=checkpoint,
ckpt_scope=ckpt_scope,
var_scope=var_scope,
var_exclude_expr=params.get('var_exclude_expr', None))
tf.train.init_from_checkpoint(checkpoint, var_map)

return tf.train.Scaffold()
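For context, here is a minimal sketch of how a scaffold function like the one above is typically wired into a TF1 estimator spec; everything except the init_from_checkpoint call and the Scaffold return is an assumption for illustration, not part of this diff.

import tensorflow.compat.v1 as tf

def make_scaffold_fn(checkpoint, var_map):
  """Builds a scaffold_fn that restores `var_map` from `checkpoint` (sketch)."""
  def scaffold_fn():
    # Register the checkpoint assignments so matching variables are
    # initialized from the pretrained checkpoint instead of randomly.
    tf.train.init_from_checkpoint(checkpoint, var_map)
    return tf.train.Scaffold()
  return scaffold_fn

# A model_fn would then pass it along, e.g. (hypothetical wiring):
# return tf.estimator.tpu.TPUEstimatorSpec(
#     mode=mode, loss=loss, train_op=train_op,
#     scaffold_fn=make_scaffold_fn(checkpoint, var_map))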
4 changes: 3 additions & 1 deletion efficientdet/hparams_config.py
@@ -187,7 +187,9 @@ def default_detection_configs():

h.lr_decay_method = 'cosine'
h.moving_average_decay = 0.9998
h.ckpt_var_scope = None
h.ckpt_var_scope = None # ckpt variable scope.
h.var_exclude_expr = None # exclude vars when loading pretrained ckpts.

h.backbone_name = 'efficientnet-b1'
h.backbone_config = None
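A hedged usage sketch for the new hparam: starting from the defaults defined above and overriding the fields a finetuning run would touch. The import path, the num_classes field, and the direct-assignment style are assumptions based on the config shown here, not part of this commit.

# Sketch: configure a finetune that skips restoring the class-prediction head.
from efficientdet import hparams_config  # assumed import path

h = hparams_config.default_detection_configs()
h.num_classes = 21                          # assumed field; new dataset's classes
h.var_exclude_expr = '.*class-predict.*'    # regex hparam added by this commit

# det_model_fn's scaffold_fn then forwards params.get('var_exclude_expr') to
# utils.get_ckpt_var_map, so the mismatched head is simply not restored.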

47 changes: 43 additions & 4 deletions efficientdet/utils.py
@@ -19,6 +19,7 @@
# gtype import
from __future__ import print_function

import re
import os
from absl import logging
import numpy as np
@@ -48,8 +49,20 @@ def get_ema_vars():
return list(set(ema_vars))


def get_ckt_var_map(ckpt_path, ckpt_scope, var_scope):
"""Get a var map for restoring from pretrained checkpoints."""
def get_ckpt_var_map(ckpt_path, ckpt_scope, var_scope, var_exclude_expr=None):
"""Get a var map for restoring from pretrained checkpoints.
Args:
ckpt_path: string. A pretrained checkpoint path.
ckpt_scope: string. Scope name for checkpoint variables.
var_scope: string. Scope name for model variables.
var_exclude_expr: string. A regex for excluding variables.
This is useful for finetuning with different classes, where
var_exclude_expr='.*class-predict.*' can be used.
Returns:
var_map: a dictionary from checkpoint name to model variables.
"""
logging.info('Init model from checkpoint {}'.format(ckpt_path))
if not ckpt_scope.endswith('/') or not var_scope.endswith('/'):
raise ValueError('Please specify scope name ending with /')
@@ -63,7 +76,14 @@ def get_ckt_var_map(ckpt_path, ckpt_scope, var_scope):
model_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=var_scope)
reader = tf.train.load_checkpoint(ckpt_path)
ckpt_var_names = set(reader.get_variable_to_shape_map().keys())

exclude_matcher = re.compile(var_exclude_expr) if var_exclude_expr else None
for v in model_vars:
if exclude_matcher and exclude_matcher.match(v.op.name):
logging.info(
'skip {} -- excluded by {}'.format(v.op.name, var_exclude_expr))
continue

if not v.op.name.startswith(var_scope):
logging.info('skip {} -- does not match scope {}'.format(
v.op.name, var_scope))
@@ -74,13 +94,26 @@ def get_ckt_var_map(ckpt_path, ckpt_scope, var_scope):
if ckpt_var not in ckpt_var_names:
logging.info('skip {} ({}) -- not in ckpt'.format(v.op.name, ckpt_var))
continue

logging.info('Init {} from ckpt var {}'.format(v.op.name, ckpt_var))
var_map[ckpt_var] = v
return var_map
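To show the exclusion rule in isolation, here is a tiny self-contained sketch of the same re.match-based filter over plain name strings; the variable names are invented examples, only the pattern semantics mirror the function above.

import re

def keep_restorable(var_names, var_exclude_expr=None):
  """Returns the names that are NOT excluded by the regex (sketch)."""
  matcher = re.compile(var_exclude_expr) if var_exclude_expr else None
  return [n for n in var_names if not (matcher and matcher.match(n))]

names = [
    'efficientnet-b1/stem/conv2d/kernel',          # invented example names
    'class_net/class-predict/pointwise_kernel',
    'box_net/box-predict/pointwise_kernel',
]
print(keep_restorable(names, '.*class-predict.*'))
# -> ['efficientnet-b1/stem/conv2d/kernel', 'box_net/box-predict/pointwise_kernel']
# Note: re.match anchors at the start of the name, hence the leading '.*'.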


def get_ckt_var_map_ema(ckpt_path, ckpt_scope, var_scope):
"""Get a ema var map for restoring from pretrained checkpoints."""
def get_ckpt_var_map_ema(ckpt_path, ckpt_scope, var_scope, var_exclude_expr):
"""Get a ema var map for restoring from pretrained checkpoints.
Args:
ckpt_path: string. A pretrained checkpoint path.
ckpt_scope: string. Scope name for checkpoint variables.
var_scope: string. Scope name for model variables.
var_exclude_expr: string. A regex for excluding variables.
This is useful for finetuning with different classes, where
var_exclude_expr='.*class-predict.*' can be used.
Returns:
var_map: a dictionary from checkpoint name to model variables.
"""
logging.info('Init model from checkpoint {}'.format(ckpt_path))
if not ckpt_scope.endswith('/') or not var_scope.endswith('/'):
raise ValueError('Please specify scope name ending with /')
@@ -94,7 +127,13 @@ def get_ckt_var_map_ema(ckpt_path, ckpt_scope, var_scope):
model_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=var_scope)
reader = tf.train.load_checkpoint(ckpt_path)
ckpt_var_names = set(reader.get_variable_to_shape_map().keys())
exclude_matcher = re.compile(var_exclude_expr) if var_exclude_expr else None
for v in model_vars:
if exclude_matcher and exclude_matcher.match(v.op.name):
logging.info(
'skip {} -- excluded by {}'.format(v.op.name, var_exclude_expr))
continue

if not v.op.name.startswith(var_scope):
logging.info('skip {} -- does not match scope {}'.format(
v.op.name, var_scope))
