diff --git a/src/transformers/models/deta/modeling_deta.py b/src/transformers/models/deta/modeling_deta.py index 0f245f2a3058f4..19f0250e6f8ce6 100644 --- a/src/transformers/models/deta/modeling_deta.py +++ b/src/transformers/models/deta/modeling_deta.py @@ -1414,14 +1414,12 @@ def get_encoder(self): def get_decoder(self): return self.decoder - # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrModel.freeze_backbone def freeze_backbone(self): - for name, param in self.backbone.conv_encoder.model.named_parameters(): + for name, param in self.backbone.model.named_parameters(): param.requires_grad_(False) - # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrModel.unfreeze_backbone def unfreeze_backbone(self): - for name, param in self.backbone.conv_encoder.model.named_parameters(): + for name, param in self.backbone.model.named_parameters(): param.requires_grad_(True) # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrModel.get_valid_ratio diff --git a/tests/models/deta/test_modeling_deta.py b/tests/models/deta/test_modeling_deta.py index d5bf32acaba7e0..8581723ccb3b72 100644 --- a/tests/models/deta/test_modeling_deta.py +++ b/tests/models/deta/test_modeling_deta.py @@ -162,6 +162,26 @@ def create_and_check_deta_model(self, config, pixel_values, pixel_mask, labels): self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.num_queries, self.hidden_size)) + def create_and_check_deta_freeze_backbone(self, config, pixel_values, pixel_mask, labels): + model = DetaModel(config=config) + model.to(torch_device) + model.eval() + + model.freeze_backbone() + + for _, param in model.backbone.model.named_parameters(): + self.parent.assertEqual(False, param.requires_grad) + + def create_and_check_deta_unfreeze_backbone(self, config, pixel_values, pixel_mask, labels): + model = DetaModel(config=config) + model.to(torch_device) + model.eval() + + model.unfreeze_backbone() + + for _, param in model.backbone.model.named_parameters(): + self.parent.assertEqual(True, param.requires_grad) + def create_and_check_deta_object_detection_head_model(self, config, pixel_values, pixel_mask, labels): model = DetaForObjectDetection(config=config) model.to(torch_device) @@ -250,6 +270,14 @@ def test_deta_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_deta_model(*config_and_inputs) + def test_deta_freeze_backbone(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_deta_freeze_backbone(*config_and_inputs) + + def test_deta_unfreeze_backbone(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_deta_unfreeze_backbone(*config_and_inputs) + def test_deta_object_detection_head_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_deta_object_detection_head_model(*config_and_inputs)