Skip to content

Commit

Permalink
add: get_config in GCViT module
Browse files Browse the repository at this point in the history
  • Loading branch information
awsaf49 committed Oct 16, 2023
1 parent 3166571 commit 0ec89ac
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions gcvit/models/gcvit.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,29 @@ def build_graph(self, input_shape=(224, 224, 3)):
def summary(self, input_shape=(224, 224, 3)):
return self.build_graph(input_shape).summary()

def get_config(self):
config = super().get_config()
config.update(
{
"window_size": self.window_size,
"dim": self.dim,
"depths": self.depths,
"num_heads": self.num_heads,
"drop_rate": self.drop_rate,
"mlp_ratio": self.mlp_ratio,
"qkv_bias": self.qkv_bias,
"qk_scale": self.qk_scale,
"attn_drop": self.attn_drop,
"path_drop": self.path_drop,
"layer_scale": self.layer_scale,
"resize_query": self.resize_query,
"global_pool": self.global_pool,
"num_classes": self.num_classes,
"head_act": self.head_act
}
)
return config


# load standard models
def GCViTXXTiny(
Expand Down

0 comments on commit 0ec89ac

Please sign in to comment.