From acc05943ccb9955086adf00a0e7a27a938878de5 Mon Sep 17 00:00:00 2001 From: Francesco Milano Date: Sun, 7 Feb 2021 10:41:08 +0100 Subject: [PATCH] Allow creating non-trainable models with factory function NOTE: it is particularly important that this command also sets the batch-normalization layers to non-trainable, which now seems to be the standard with Tensorflow 2 + Keras, but is not yet handled well by, e.g., the models from `segmentation_models`. Cf. `freeze_model` from `segmentation_models/models/_utils.py` and, e.g., https://keras.io/getting_started/faq/#whats-the-difference-between-the-training-argument-in-call-and-the-trainable-attribute and https://github.com/keras-team/keras/pull/9965. --- src/bfseg/cl_models/base_cl_model.py | 6 ++++++ src/bfseg/cl_models/distillation_model.py | 3 ++- src/bfseg/utils/models.py | 17 ++++++++++++++++- src/nyu_pretraining.py | 1 + 4 files changed, 25 insertions(+), 2 deletions(-) diff --git a/src/bfseg/cl_models/base_cl_model.py b/src/bfseg/cl_models/base_cl_model.py index 3016e43d..1d000ada 100644 --- a/src/bfseg/cl_models/base_cl_model.py +++ b/src/bfseg/cl_models/base_cl_model.py @@ -52,8 +52,14 @@ def _build_model(self): assert (self.run.config['cl_params']['cl_framework'] in ["ewc", "finetune" ]), "Currently, only EWC and fine-tuning are supported." + # NOTE: by default the model is created as trainable. CL frameworks that + # require a fixed, non-trainable network from which to distill the + # information (e.g., in distillation experiments) should create additional + # models by overloading this method and calling `super()._build_model()` in + # the overload. 
self.encoder, self.model = create_model( model_name=self.run.config['network_params']['architecture'], + trainable=True, **self.run.config['network_params']['model_params']) self.new_model = keras.Model( inputs=self.model.input, diff --git a/src/bfseg/cl_models/distillation_model.py b/src/bfseg/cl_models/distillation_model.py index f20fa4b8..a38dbe10 100644 --- a/src/bfseg/cl_models/distillation_model.py +++ b/src/bfseg/cl_models/distillation_model.py @@ -43,7 +43,8 @@ def __init__(self, run, root_output_dir): "Distillation model requires the CL parameter `distillation_type` " "to be specified.") - super(DistillationModel, self).__init__(run=run, root_output_dir=root_output_dir) + super(DistillationModel, self).__init__(run=run, + root_output_dir=root_output_dir) self._started_training_new_task = False diff --git a/src/bfseg/utils/models.py b/src/bfseg/utils/models.py index 9c6f50d4..910311a4 100644 --- a/src/bfseg/utils/models.py +++ b/src/bfseg/utils/models.py @@ -9,6 +9,7 @@ def create_model(model_name, image_h, image_w, + trainable, log_params_used=True, **model_params): r"""Factory function that creates a model with the given parameters. @@ -18,6 +19,7 @@ def create_model(model_name, "fast_scnn", "unet". image_h (int): Image height. image_w (int): Image width. + trainable (bool): Whether or not the model should be trainable. log_params_used (bool): If True, the complete list of parameters used to instantiate the model is printed. --- @@ -91,4 +93,17 @@ def create_model(model_name, encoder, model = model_fn(**model_params) - return encoder, model + # Optionally set the model as non-trainable. + if (not trainable): + # NOTE: it is particularly important that this command also sets the + # batch-normalization layers to non-trainable, which now seems to be the + # standard with Tensorflow 2 + Keras, but is not yet handled well by, e.g., + # the models from `segmentation_models`. + # Cf. 
`freeze_model` from `segmentation_models/models/_utils.py` and, e.g., + # https://keras.io/getting_started/faq/#whats-the-difference-between-the- + # training-argument-in-call-and-the-trainable-attribute and + # https://github.com/keras-team/keras/pull/9965. + encoder.trainable = False + model.trainable = False + + return encoder, model \ No newline at end of file diff --git a/src/nyu_pretraining.py b/src/nyu_pretraining.py index e864e04e..d3e56303 100644 --- a/src/nyu_pretraining.py +++ b/src/nyu_pretraining.py @@ -38,6 +38,7 @@ def pretrain_nyu(_run, _, model = create_model(model_name="fast_scnn", image_h=image_h, image_w=image_w, + trainable=True, num_downsampling_layers=2) model.compile(