Allow non-trainable models with factory function

NOTE: it is particularly important that setting `trainable = False` also sets the batch-normalization layers to non-trainable, which now seems to be the standard with TensorFlow 2 + Keras, but is not yet handled well by, e.g., the models from `segmentation_models`. Cf. `freeze_model` from `segmentation_models/models/_utils.py` and, e.g., https://keras.io/getting_started/faq/#whats-the-difference-between-the-training-argument-in-call-and-the-trainable-attribute and keras-team/keras#9965.
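A minimal sketch of the behavior this relies on (assuming TensorFlow 2.x; the toy model below is illustrative and not part of this repository):

from tensorflow import keras

# Toy model containing a batch-normalization layer.
inputs = keras.Input(shape=(32, 32, 3))
x = keras.layers.Conv2D(8, 3, padding="same")(inputs)
x = keras.layers.BatchNormalization()(x)
model = keras.Model(inputs, x)

# Freezing the model recursively freezes all its sublayers, batch
# normalization included.
model.trainable = False
assert all(not layer.trainable for layer in model.layers)
assert len(model.trainable_weights) == 0
# Per the Keras FAQ linked above, a frozen BatchNormalization layer in
# TF 2 also runs in inference mode (i.e., it uses its moving statistics),
# even if the model is called with training=True.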
francescomilano172 committed Feb 7, 2021
1 parent aa3fa3e commit acc0594
Showing 4 changed files with 25 additions and 2 deletions.
6 changes: 6 additions & 0 deletions src/bfseg/cl_models/base_cl_model.py
@@ -52,8 +52,14 @@ def _build_model(self):
assert (self.run.config['cl_params']['cl_framework']
in ["ewc", "finetune"
]), "Currently, only EWC and fine-tuning are supported."
# NOTE: by default the model is created as trainable. CL frameworks that
# require a fixed, non-trainable network from which to distill the
# information (e.g., in distillation experiments) should create additional
# models by overloading this method and calling `super()._build_model()` in
# the overload.
self.encoder, self.model = create_model(
model_name=self.run.config['network_params']['architecture'],
trainable=True,
**self.run.config['network_params']['model_params'])
self.new_model = keras.Model(
inputs=self.model.input,
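For reference, a hedged sketch of the overloading pattern described in the NOTE above (the subclass, the `fixed_*` attribute names, and the base-class name `BaseCLModel` are assumptions for illustration, not taken from this commit):

class MyDistillationModel(BaseCLModel):
  r"""Sketch: a CL framework distilling from a fixed, non-trainable net."""

  def _build_model(self):
    # First build the regular, trainable model.
    super()._build_model()
    # Then create the additional fixed network from which to distill;
    # `trainable=False` also freezes its batch-normalization layers.
    self.fixed_encoder, self.fixed_model = create_model(
        model_name=self.run.config['network_params']['architecture'],
        trainable=False,
        **self.run.config['network_params']['model_params'])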
3 changes: 2 additions & 1 deletion src/bfseg/cl_models/distillation_model.py
@@ -43,7 +43,8 @@ def __init__(self, run, root_output_dir):
"Distillation model requires the CL parameter `distillation_type` "
"to be specified.")

super(DistillationModel, self).__init__(run=run, root_output_dir=root_output_dir)
super(DistillationModel, self).__init__(run=run,
root_output_dir=root_output_dir)

self._started_training_new_task = False

17 changes: 16 additions & 1 deletion src/bfseg/utils/models.py
@@ -9,6 +9,7 @@
def create_model(model_name,
image_h,
image_w,
trainable,
log_params_used=True,
**model_params):
r"""Factory function that creates a model with the given parameters.
@@ -18,6 +19,7 @@ def create_model(model_name,
"fast_scnn", "unet".
image_h (int): Image height.
image_w (int): Image width.
trainable (bool): Whether or not the model should be trainable.
log_params_used (bool): If True, the complete list of parameters used to
instantiate the model is printed.
---
@@ -91,4 +93,17 @@

encoder, model = model_fn(**model_params)

return encoder, model
# Optionally set the model as non-trainable.
if (not trainable):
# NOTE: it is particularly important that this command also sets the
# batch-normalization layers to non-trainable, which now seems to be the
# standard with Tensorflow 2 + Keras, but is not yet handled well by, e.g.,
# the models from `segmentation_models`.
# Cf. `freeze_model` from `segmentation_models/models/_utils.py` and, e.g.,
# https://keras.io/getting_started/faq/#whats-the-difference-between-the-
# training-argument-in-call-and-the-trainable-attribute and
# https://github.com/keras-team/keras/pull/9965.
encoder.trainable = False
model.trainable = False

return encoder, model
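A brief usage sketch of the new flag (the architecture and `num_downsampling_layers` mirror `src/nyu_pretraining.py` below; the image size is an assumed, illustrative value):

from bfseg.utils.models import create_model

# A fixed teacher network: `trainable=False` freezes all layers,
# batch normalization included.
encoder, teacher = create_model(model_name="fast_scnn",
                                image_h=480,  # assumed value
                                image_w=640,  # assumed value
                                trainable=False,
                                num_downsampling_layers=2)
assert len(teacher.trainable_weights) == 0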
1 change: 1 addition & 0 deletions src/nyu_pretraining.py
@@ -38,6 +38,7 @@ def pretrain_nyu(_run,
_, model = create_model(model_name="fast_scnn",
image_h=image_h,
image_w=image_w,
trainable=True,
num_downsampling_layers=2)

model.compile(
