From 0ec89ac850e2501216a30b9af30b14a919ee441d Mon Sep 17 00:00:00 2001 From: Awsaf Date: Mon, 16 Oct 2023 14:20:34 +0600 Subject: [PATCH] add: `get_config` in `GCViT` module --- gcvit/models/gcvit.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/gcvit/models/gcvit.py b/gcvit/models/gcvit.py index c2f6d8f..3ca72a9 100644 --- a/gcvit/models/gcvit.py +++ b/gcvit/models/gcvit.py @@ -209,6 +209,29 @@ def build_graph(self, input_shape=(224, 224, 3)): def summary(self, input_shape=(224, 224, 3)): return self.build_graph(input_shape).summary() + def get_config(self): + config = super().get_config() + config.update( + { + "window_size": self.window_size, + "dim": self.dim, + "depths": self.depths, + "num_heads": self.num_heads, + "drop_rate": self.drop_rate, + "mlp_ratio": self.mlp_ratio, + "qkv_bias": self.qkv_bias, + "qk_scale": self.qk_scale, + "attn_drop": self.attn_drop, + "path_drop": self.path_drop, + "layer_scale": self.layer_scale, + "resize_query": self.resize_query, + "global_pool": self.global_pool, + "num_classes": self.num_classes, + "head_act": self.head_act + } + ) + return config + # load standard models def GCViTXXTiny(